1818
1919#include " mlir/IR/Operation.h"
2020#include " mlir/Support/StorageUniquer.h"
21+ #include " llvm/ADT/EquivalenceClasses.h"
2122#include " llvm/ADT/Hashing.h"
2223#include " llvm/ADT/SetVector.h"
2324#include " llvm/Support/Compiler.h"
@@ -265,6 +266,14 @@ struct LatticeAnchor
265266// / Forward declaration of the data-flow analysis class.
266267class DataFlowAnalysis ;
267268
269+ } // namespace mlir
270+
271+ template <>
272+ struct llvm ::DenseMapInfo<mlir::LatticeAnchor>
273+ : public llvm::DenseMapInfo<mlir::LatticeAnchor::ParentTy> {};
274+
275+ namespace mlir {
276+
268277// ===----------------------------------------------------------------------===//
269278// DataFlowConfig
270279// ===----------------------------------------------------------------------===//
@@ -332,7 +341,9 @@ class DataFlowSolver {
332341 // / does not exist.
333342 template <typename StateT, typename AnchorT>
334343 const StateT *lookupState (AnchorT anchor) const {
335- const auto &mapIt = analysisStates.find (LatticeAnchor (anchor));
344+ LatticeAnchor latticeAnchor =
345+ getLeaderAnchorOrSelf<StateT>(LatticeAnchor (anchor));
346+ const auto &mapIt = analysisStates.find (latticeAnchor);
336347 if (mapIt == analysisStates.end ())
337348 return nullptr ;
338349 auto it = mapIt->second .find (TypeID::get<StateT>());
@@ -344,12 +355,34 @@ class DataFlowSolver {
344355 // / Erase any analysis state associated with the given lattice anchor.
345356 template <typename AnchorT>
346357 void eraseState (AnchorT anchor) {
347- LatticeAnchor la (anchor);
348- analysisStates.erase (LatticeAnchor (anchor));
358+ LatticeAnchor latticeAnchor (anchor);
359+
360+ // Update equivalentAnchorMap.
361+ for (auto &&[TypeId, eqClass] : equivalentAnchorMap) {
362+ if (!eqClass.contains (latticeAnchor)) {
363+ continue ;
364+ }
365+ llvm::EquivalenceClasses<LatticeAnchor>::member_iterator leaderIt =
366+ eqClass.findLeader (latticeAnchor);
367+
368+ // Update analysis states with new leader if needed.
369+ if (*leaderIt == latticeAnchor && ++leaderIt != eqClass.member_end ()) {
370+ analysisStates[*leaderIt][TypeId] =
371+ std::move (analysisStates[latticeAnchor][TypeId]);
372+ }
373+
374+ eqClass.erase (latticeAnchor);
375+ }
376+
377+ // Update analysis states.
378+ analysisStates.erase (latticeAnchor);
349379 }
350380
351- // Erase all analysis states
352- void eraseAllStates () { analysisStates.clear (); }
381+ // Erase all analysis states.
382+ void eraseAllStates () {
383+ analysisStates.clear ();
384+ equivalentAnchorMap.clear ();
385+ }
353386
354387 // / Get a uniqued lattice anchor instance. If one is not present, it is
355388 // / created with the provided arguments.
@@ -399,6 +432,19 @@ class DataFlowSolver {
399432 template <typename StateT, typename AnchorT>
400433 StateT *getOrCreateState (AnchorT anchor);
401434
435+ // / Get leader lattice anchor in equivalence lattice anchor group, return
436+ // / input lattice anchor if input not found in equivalece lattice anchor group.
437+ template <typename StateT>
438+ LatticeAnchor getLeaderAnchorOrSelf (LatticeAnchor latticeAnchor) const ;
439+
440+ // / Union input anchors under the given state.
441+ template <typename StateT, typename AnchorT>
442+ void unionLatticeAnchors (AnchorT anchor, AnchorT other);
443+
444+ // / Return given lattice is equivalent on given state.
445+ template <typename StateT>
446+ bool isEquivalent (LatticeAnchor lhs, LatticeAnchor rhs) const ;
447+
402448 // / Propagate an update to an analysis state if it changed by pushing
403449 // / dependent work items to the back of the queue.
404450 // / This should only be used when DataFlowSolver is running.
@@ -429,10 +475,15 @@ class DataFlowSolver {
429475
430476 // / A type-erased map of lattice anchors to associated analysis states for
431477 // / first-class lattice anchors.
432- DenseMap<LatticeAnchor, DenseMap<TypeID, std::unique_ptr<AnalysisState>>,
433- DenseMapInfo<LatticeAnchor::ParentTy>>
478+ DenseMap<LatticeAnchor, DenseMap<TypeID, std::unique_ptr<AnalysisState>>>
434479 analysisStates;
435480
481+ // / A type-erased map of lattice type to the equivalet lattice anchors.
482+ // / Lattice anchors are considered equivalent under a certain lattice type if
483+ // / and only if, under this lattice type, the lattices pointed to by these
484+ // / lattice anchors necessarily contain identical value.
485+ DenseMap<TypeID, llvm::EquivalenceClasses<LatticeAnchor>> equivalentAnchorMap;
486+
436487 // / Allow the base child analysis class to access the internals of the solver.
437488 friend class DataFlowAnalysis ;
438489};
@@ -564,6 +615,14 @@ class DataFlowAnalysis {
564615 // / will provide a value for then.
565616 virtual LogicalResult visit (ProgramPoint *point) = 0;
566617
618+ // / Initialize lattice anchor equivalence class from the provided top-level
619+ // / operation.
620+ // /
621+ // / This function will union lattice anchor to same equivalent class if the
622+ // / analysis can determine the lattice content of lattice anchor is
623+ // / necessarily identical under the corrensponding lattice type.
624+ virtual void initializeEquivalentLatticeAnchor (Operation *top) { return ; }
625+
567626protected:
568627 // / Create a dependency between the given analysis state and lattice anchor
569628 // / on this analysis.
@@ -584,6 +643,12 @@ class DataFlowAnalysis {
584643 return solver.getLatticeAnchor <AnchorT>(std::forward<Args>(args)...);
585644 }
586645
646+ // / Union input anchors under the given state.
647+ template <typename StateT, typename AnchorT>
648+ void unionLatticeAnchors (AnchorT anchor, AnchorT other) {
649+ return solver.unionLatticeAnchors <StateT>(anchor, other);
650+ }
651+
587652 // / Get the analysis state associated with the lattice anchor. The returned
588653 // / state is expected to be "write-only", and any updates need to be
589654 // / propagated by `propagateIfChanged`.
@@ -598,7 +663,9 @@ class DataFlowAnalysis {
598663 template <typename StateT, typename AnchorT>
599664 const StateT *getOrCreateFor (ProgramPoint *dependent, AnchorT anchor) {
600665 StateT *state = getOrCreate<StateT>(anchor);
601- addDependency (state, dependent);
666+ if (!solver.isEquivalent <StateT>(LatticeAnchor (anchor),
667+ LatticeAnchor (dependent)))
668+ addDependency (state, dependent);
602669 return state;
603670 }
604671
@@ -644,10 +711,26 @@ AnalysisT *DataFlowSolver::load(Args &&...args) {
644711 return static_cast <AnalysisT *>(childAnalyses.back ().get ());
645712}
646713
714+ template <typename StateT>
715+ LatticeAnchor
716+ DataFlowSolver::getLeaderAnchorOrSelf (LatticeAnchor latticeAnchor) const {
717+ const llvm::EquivalenceClasses<LatticeAnchor> eqClass =
718+ equivalentAnchorMap.lookup (TypeID::get<StateT>());
719+ llvm::EquivalenceClasses<LatticeAnchor>::member_iterator leaderIt =
720+ eqClass.findLeader (latticeAnchor);
721+ if (leaderIt != eqClass.member_end ()) {
722+ return *leaderIt;
723+ }
724+ return latticeAnchor;
725+ }
726+
647727template <typename StateT, typename AnchorT>
648728StateT *DataFlowSolver::getOrCreateState (AnchorT anchor) {
729+ // Replace to leader anchor if found.
730+ LatticeAnchor latticeAnchor (anchor);
731+ latticeAnchor = getLeaderAnchorOrSelf<StateT>(latticeAnchor);
649732 std::unique_ptr<AnalysisState> &state =
650- analysisStates[LatticeAnchor (anchor) ][TypeID::get<StateT>()];
733+ analysisStates[latticeAnchor ][TypeID::get<StateT>()];
651734 if (!state) {
652735 state = std::unique_ptr<StateT>(new StateT (anchor));
653736#if LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -657,6 +740,22 @@ StateT *DataFlowSolver::getOrCreateState(AnchorT anchor) {
657740 return static_cast <StateT *>(state.get ());
658741}
659742
743+ template <typename StateT>
744+ bool DataFlowSolver::isEquivalent (LatticeAnchor lhs, LatticeAnchor rhs) const {
745+ const llvm::EquivalenceClasses<LatticeAnchor> eqClass =
746+ equivalentAnchorMap.lookup (TypeID::get<StateT>());
747+ if (!eqClass.contains (lhs) || !eqClass.contains (rhs))
748+ return false ;
749+ return eqClass.isEquivalent (lhs, rhs);
750+ }
751+
752+ template <typename StateT, typename AnchorT>
753+ void DataFlowSolver::unionLatticeAnchors (AnchorT anchor, AnchorT other) {
754+ llvm::EquivalenceClasses<LatticeAnchor> &eqClass =
755+ equivalentAnchorMap[TypeID::get<StateT>()];
756+ eqClass.unionSets (LatticeAnchor (anchor), LatticeAnchor (other));
757+ }
758+
660759inline raw_ostream &operator <<(raw_ostream &os, const AnalysisState &state) {
661760 state.print (os);
662761 return os;
0 commit comments