Skip to content

Commit 623992f

Browse files
committed
[mlir] [dataflow] : Improve the time and space footprint of data flow.
MLIR's data flow analysis (especially dense data flow analysis) constructs a lattice at every lattice anchor (which, for dense data flow, means every program point). As the program grows larger, the time and space complexity can become unmanageable. However, in many programs, the lattice values at numerous lattice anchors are actually identical. We can leverage this observation to improve the complexity of data flow analysis. This patch introducing equivalence lattice anchor to group lattice anchors that must contains identical lattice on certain state to improve the time and space footprint of data flow.
1 parent 862e719 commit 623992f

File tree

6 files changed

+182
-9
lines changed

6 files changed

+182
-9
lines changed

mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
7373
/// may modify the program state; that is, every operation and block.
7474
LogicalResult initialize(Operation *top) override;
7575

76+
/// Initialize lattice anchor equivalence class from the provided top-level
77+
/// operation.
78+
///
79+
/// This function will union lattice anchor to same equivalent class if the
80+
/// analysis can determine the lattice content of lattice anchor is
81+
/// necessarily identical under the corrensponding lattice type.
82+
virtual void initializeEquivalentLatticeAnchor(Operation *top) override;
83+
7684
/// Visit a program point that modifies the state of the program. If the
7785
/// program point is at the beginning of a block, then the state is propagated
7886
/// from control-flow predecessors or callsites. If the operation before
@@ -114,6 +122,11 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
114122
/// operation transfer function.
115123
virtual LogicalResult processOperation(Operation *op);
116124

125+
/// Visit an operation. If this analysis can confirm that lattice content
126+
/// of lattice anchors around operation are necessarily identical, join
127+
/// them into the same equivalent class.
128+
virtual void buildOperationEquivalentLatticeAnchor(Operation *op) { return; }
129+
117130
/// Propagate the dense lattice forward along the control flow edge from
118131
/// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt`
119132
/// values correspond to control flow branches originating at or targeting the
@@ -310,6 +323,14 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
310323
/// may modify the program state; that is, every operation and block.
311324
LogicalResult initialize(Operation *top) override;
312325

326+
/// Initialize lattice anchor equivalence class from the provided top-level
327+
/// operation.
328+
///
329+
/// This function will union lattice anchor to same equivalent class if the
330+
/// analysis can determine the lattice content of lattice anchor is
331+
/// necessarily identical under the corrensponding lattice type.
332+
virtual void initializeEquivalentLatticeAnchor(Operation *top) override;
333+
313334
/// Visit a program point that modifies the state of the program. The state is
314335
/// propagated along control flow directions for branch-, region- and
315336
/// call-based control flow using the respective interfaces. For other
@@ -353,6 +374,11 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
353374
/// transfer function.
354375
virtual LogicalResult processOperation(Operation *op);
355376

377+
/// Visit an operation. If this analysis can confirm that lattice content
378+
/// of lattice anchors around operation are necessarily identical, join
379+
/// them into the same equivalent class.
380+
virtual void buildOperationEquivalentLatticeAnchor(Operation *op) { return; }
381+
356382
/// Propagate the dense lattice backwards along the control flow edge from
357383
/// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt`
358384
/// values correspond to control flow branches originating at or targeting the

mlir/include/mlir/Analysis/DataFlowFramework.h

Lines changed: 108 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include "mlir/IR/Operation.h"
2020
#include "mlir/Support/StorageUniquer.h"
21+
#include "llvm/ADT/EquivalenceClasses.h"
2122
#include "llvm/ADT/Hashing.h"
2223
#include "llvm/ADT/SetVector.h"
2324
#include "llvm/Support/Compiler.h"
@@ -265,6 +266,14 @@ struct LatticeAnchor
265266
/// Forward declaration of the data-flow analysis class.
266267
class DataFlowAnalysis;
267268

269+
} // namespace mlir
270+
271+
template <>
272+
struct llvm::DenseMapInfo<mlir::LatticeAnchor>
273+
: public llvm::DenseMapInfo<mlir::LatticeAnchor::ParentTy> {};
274+
275+
namespace mlir {
276+
268277
//===----------------------------------------------------------------------===//
269278
// DataFlowConfig
270279
//===----------------------------------------------------------------------===//
@@ -332,7 +341,9 @@ class DataFlowSolver {
332341
/// does not exist.
333342
template <typename StateT, typename AnchorT>
334343
const StateT *lookupState(AnchorT anchor) const {
335-
const auto &mapIt = analysisStates.find(LatticeAnchor(anchor));
344+
LatticeAnchor latticeAnchor =
345+
getLeaderAnchorOrSelf<StateT>(LatticeAnchor(anchor));
346+
const auto &mapIt = analysisStates.find(latticeAnchor);
336347
if (mapIt == analysisStates.end())
337348
return nullptr;
338349
auto it = mapIt->second.find(TypeID::get<StateT>());
@@ -344,12 +355,34 @@ class DataFlowSolver {
344355
/// Erase any analysis state associated with the given lattice anchor.
345356
template <typename AnchorT>
346357
void eraseState(AnchorT anchor) {
347-
LatticeAnchor la(anchor);
348-
analysisStates.erase(LatticeAnchor(anchor));
358+
LatticeAnchor latticeAnchor(anchor);
359+
360+
// Update equivalentAnchorMap.
361+
for (auto &&[TypeId, eqClass] : equivalentAnchorMap) {
362+
if (!eqClass.contains(latticeAnchor)) {
363+
continue;
364+
}
365+
llvm::EquivalenceClasses<LatticeAnchor>::member_iterator leaderIt =
366+
eqClass.findLeader(latticeAnchor);
367+
368+
// Update analysis states with new leader if needed.
369+
if (*leaderIt == latticeAnchor && ++leaderIt != eqClass.member_end()) {
370+
analysisStates[*leaderIt][TypeId] =
371+
std::move(analysisStates[latticeAnchor][TypeId]);
372+
}
373+
374+
eqClass.erase(latticeAnchor);
375+
}
376+
377+
// Update analysis states.
378+
analysisStates.erase(latticeAnchor);
349379
}
350380

351-
// Erase all analysis states
352-
void eraseAllStates() { analysisStates.clear(); }
381+
// Erase all analysis states.
382+
void eraseAllStates() {
383+
analysisStates.clear();
384+
equivalentAnchorMap.clear();
385+
}
353386

354387
/// Get a uniqued lattice anchor instance. If one is not present, it is
355388
/// created with the provided arguments.
@@ -399,6 +432,19 @@ class DataFlowSolver {
399432
template <typename StateT, typename AnchorT>
400433
StateT *getOrCreateState(AnchorT anchor);
401434

435+
/// Get leader lattice anchor in equivalence lattice anchor group, return
436+
/// input lattice anchor if input not found in equivalece lattice anchor group.
437+
template <typename StateT>
438+
LatticeAnchor getLeaderAnchorOrSelf(LatticeAnchor latticeAnchor) const;
439+
440+
/// Union input anchors under the given state.
441+
template <typename StateT, typename AnchorT>
442+
void unionLatticeAnchors(AnchorT anchor, AnchorT other);
443+
444+
/// Return given lattice is equivalent on given state.
445+
template <typename StateT>
446+
bool isEquivalent(LatticeAnchor lhs, LatticeAnchor rhs) const;
447+
402448
/// Propagate an update to an analysis state if it changed by pushing
403449
/// dependent work items to the back of the queue.
404450
/// This should only be used when DataFlowSolver is running.
@@ -429,10 +475,15 @@ class DataFlowSolver {
429475

430476
/// A type-erased map of lattice anchors to associated analysis states for
431477
/// first-class lattice anchors.
432-
DenseMap<LatticeAnchor, DenseMap<TypeID, std::unique_ptr<AnalysisState>>,
433-
DenseMapInfo<LatticeAnchor::ParentTy>>
478+
DenseMap<LatticeAnchor, DenseMap<TypeID, std::unique_ptr<AnalysisState>>>
434479
analysisStates;
435480

481+
/// A type-erased map of lattice type to the equivalet lattice anchors.
482+
/// Lattice anchors are considered equivalent under a certain lattice type if
483+
/// and only if, under this lattice type, the lattices pointed to by these
484+
/// lattice anchors necessarily contain identical value.
485+
DenseMap<TypeID, llvm::EquivalenceClasses<LatticeAnchor>> equivalentAnchorMap;
486+
436487
/// Allow the base child analysis class to access the internals of the solver.
437488
friend class DataFlowAnalysis;
438489
};
@@ -564,6 +615,14 @@ class DataFlowAnalysis {
564615
/// will provide a value for then.
565616
virtual LogicalResult visit(ProgramPoint *point) = 0;
566617

618+
/// Initialize lattice anchor equivalence class from the provided top-level
619+
/// operation.
620+
///
621+
/// This function will union lattice anchor to same equivalent class if the
622+
/// analysis can determine the lattice content of lattice anchor is
623+
/// necessarily identical under the corrensponding lattice type.
624+
virtual void initializeEquivalentLatticeAnchor(Operation *top) { return; }
625+
567626
protected:
568627
/// Create a dependency between the given analysis state and lattice anchor
569628
/// on this analysis.
@@ -584,6 +643,12 @@ class DataFlowAnalysis {
584643
return solver.getLatticeAnchor<AnchorT>(std::forward<Args>(args)...);
585644
}
586645

646+
/// Union input anchors under the given state.
647+
template <typename StateT, typename AnchorT>
648+
void unionLatticeAnchors(AnchorT anchor, AnchorT other) {
649+
return solver.unionLatticeAnchors<StateT>(anchor, other);
650+
}
651+
587652
/// Get the analysis state associated with the lattice anchor. The returned
588653
/// state is expected to be "write-only", and any updates need to be
589654
/// propagated by `propagateIfChanged`.
@@ -598,7 +663,9 @@ class DataFlowAnalysis {
598663
template <typename StateT, typename AnchorT>
599664
const StateT *getOrCreateFor(ProgramPoint *dependent, AnchorT anchor) {
600665
StateT *state = getOrCreate<StateT>(anchor);
601-
addDependency(state, dependent);
666+
if (!solver.isEquivalent<StateT>(LatticeAnchor(anchor),
667+
LatticeAnchor(dependent)))
668+
addDependency(state, dependent);
602669
return state;
603670
}
604671

@@ -644,10 +711,26 @@ AnalysisT *DataFlowSolver::load(Args &&...args) {
644711
return static_cast<AnalysisT *>(childAnalyses.back().get());
645712
}
646713

714+
template <typename StateT>
715+
LatticeAnchor
716+
DataFlowSolver::getLeaderAnchorOrSelf(LatticeAnchor latticeAnchor) const {
717+
const llvm::EquivalenceClasses<LatticeAnchor> eqClass =
718+
equivalentAnchorMap.lookup(TypeID::get<StateT>());
719+
llvm::EquivalenceClasses<LatticeAnchor>::member_iterator leaderIt =
720+
eqClass.findLeader(latticeAnchor);
721+
if (leaderIt != eqClass.member_end()) {
722+
return *leaderIt;
723+
}
724+
return latticeAnchor;
725+
}
726+
647727
template <typename StateT, typename AnchorT>
648728
StateT *DataFlowSolver::getOrCreateState(AnchorT anchor) {
729+
// Replace to leader anchor if found.
730+
LatticeAnchor latticeAnchor(anchor);
731+
latticeAnchor = getLeaderAnchorOrSelf<StateT>(latticeAnchor);
649732
std::unique_ptr<AnalysisState> &state =
650-
analysisStates[LatticeAnchor(anchor)][TypeID::get<StateT>()];
733+
analysisStates[latticeAnchor][TypeID::get<StateT>()];
651734
if (!state) {
652735
state = std::unique_ptr<StateT>(new StateT(anchor));
653736
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -657,6 +740,22 @@ StateT *DataFlowSolver::getOrCreateState(AnchorT anchor) {
657740
return static_cast<StateT *>(state.get());
658741
}
659742

743+
template <typename StateT>
744+
bool DataFlowSolver::isEquivalent(LatticeAnchor lhs, LatticeAnchor rhs) const {
745+
const llvm::EquivalenceClasses<LatticeAnchor> eqClass =
746+
equivalentAnchorMap.lookup(TypeID::get<StateT>());
747+
if (!eqClass.contains(lhs) || !eqClass.contains(rhs))
748+
return false;
749+
return eqClass.isEquivalent(lhs, rhs);
750+
}
751+
752+
template <typename StateT, typename AnchorT>
753+
void DataFlowSolver::unionLatticeAnchors(AnchorT anchor, AnchorT other) {
754+
llvm::EquivalenceClasses<LatticeAnchor> &eqClass =
755+
equivalentAnchorMap[TypeID::get<StateT>()];
756+
eqClass.unionSets(LatticeAnchor(anchor), LatticeAnchor(other));
757+
}
758+
660759
inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) {
661760
state.print(os);
662761
return os;

mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@ using namespace mlir::dataflow;
2828
// AbstractDenseForwardDataFlowAnalysis
2929
//===----------------------------------------------------------------------===//
3030

31+
void AbstractDenseForwardDataFlowAnalysis::initializeEquivalentLatticeAnchor(
32+
Operation *top) {
33+
top->walk([&](Operation *op) {
34+
if (isa<RegionBranchOpInterface, CallOpInterface>(op))
35+
return;
36+
buildOperationEquivalentLatticeAnchor(op);
37+
});
38+
}
39+
3140
LogicalResult AbstractDenseForwardDataFlowAnalysis::initialize(Operation *top) {
3241
// Visit every operation and block.
3342
if (failed(processOperation(top)))
@@ -252,6 +261,15 @@ AbstractDenseForwardDataFlowAnalysis::getLatticeFor(ProgramPoint *dependent,
252261
// AbstractDenseBackwardDataFlowAnalysis
253262
//===----------------------------------------------------------------------===//
254263

264+
void AbstractDenseBackwardDataFlowAnalysis::initializeEquivalentLatticeAnchor(
265+
Operation *top) {
266+
top->walk([&](Operation *op) {
267+
if (isa<RegionBranchOpInterface, CallOpInterface>(op))
268+
return;
269+
buildOperationEquivalentLatticeAnchor(op);
270+
});
271+
}
272+
255273
LogicalResult
256274
AbstractDenseBackwardDataFlowAnalysis::initialize(Operation *top) {
257275
// Visit every operation and block.

mlir/lib/Analysis/DataFlowFramework.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,11 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
109109
isRunning = true;
110110
auto guard = llvm::make_scope_exit([&]() { isRunning = false; });
111111

112+
// Initialize equivalent lattice anchors.
113+
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
114+
analysis.initializeEquivalentLatticeAnchor(top);
115+
}
116+
112117
// Initialize the analyses.
113118
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
114119
DATAFLOW_DEBUG(llvm::dbgs()

mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> {
7676
propagateIfChanged(lattice, lattice->setKnownToUnknown());
7777
}
7878

79+
/// Visit an operation. If this analysis can confirm that lattice content
80+
/// of lattice anchors around operation are necessarily identical, join
81+
/// them into the same equivalent class.
82+
void buildOperationEquivalentLatticeAnchor(Operation *op) override;
83+
7984
const bool assumeFuncReads;
8085
};
8186
} // namespace
@@ -141,6 +146,13 @@ LogicalResult NextAccessAnalysis::visitOperation(Operation *op,
141146
return success();
142147
}
143148

149+
void NextAccessAnalysis::buildOperationEquivalentLatticeAnchor(Operation *op) {
150+
if (isMemoryEffectFree(op)) {
151+
unionLatticeAnchors<NextAccess>(getProgramPointBefore(op),
152+
getProgramPointAfter(op));
153+
}
154+
}
155+
144156
void NextAccessAnalysis::visitCallControlFlowTransfer(
145157
CallOpInterface call, CallControlFlowAction action, const NextAccess &after,
146158
NextAccess *before) {

mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ class LastModifiedAnalysis
7272
const LastModification &before,
7373
LastModification *after) override;
7474

75+
/// Visit an operation. If this analysis can confirm that lattice content
76+
/// of lattice anchors around operation are necessarily identical, join
77+
/// them into the same equivalent class.
78+
void buildOperationEquivalentLatticeAnchor(Operation *op) override;
79+
7580
/// At an entry point, the last modifications of all memory resources are
7681
/// unknown.
7782
void setToEntryState(LastModification *lattice) override {
@@ -147,6 +152,14 @@ LogicalResult LastModifiedAnalysis::visitOperation(
147152
return success();
148153
}
149154

155+
void LastModifiedAnalysis::buildOperationEquivalentLatticeAnchor(
156+
Operation *op) {
157+
if (isMemoryEffectFree(op)) {
158+
unionLatticeAnchors<LastModification>(getProgramPointBefore(op),
159+
getProgramPointAfter(op));
160+
}
161+
}
162+
150163
void LastModifiedAnalysis::visitCallControlFlowTransfer(
151164
CallOpInterface call, CallControlFlowAction action,
152165
const LastModification &before, LastModification *after) {

0 commit comments

Comments
 (0)