1515
1616#include " MatchFinder.h"
1717#include " MatchersInternal.h"
18+ #include " mlir/IR/Region.h"
19+ #include " mlir/Query/Query.h"
20+ #include " llvm/Support/raw_ostream.h"
1821
1922namespace mlir {
2023
@@ -24,80 +27,161 @@ namespace extramatcher {
2427
2528namespace detail {
2629
27- class DefinitionsMatcher {
30+ class BackwardSliceMatcher {
2831public:
29- DefinitionsMatcher (matcher::DynMatcher &&InnerMatcher , unsigned Hops )
30- : InnerMatcher (std::move(InnerMatcher )), Hops(Hops ) {}
32+ BackwardSliceMatcher (matcher::DynMatcher &&innerMatcher , unsigned hops )
33+ : innerMatcher (std::move(innerMatcher )), hops(hops ) {}
3134
3235private:
33- bool matches (Operation *op, matcher::BoundOperationsGraphBuilder &Bound,
34- unsigned TempHops) {
35-
36- llvm::DenseSet<mlir::Value> Ccache;
37- llvm::SmallVector<std::pair<Operation *, size_t >, 4 > TempStorage;
38- TempStorage.push_back ({op, TempHops});
39- while (!TempStorage.empty ()) {
40- auto [CurrentOp, RemainingHops] = TempStorage.pop_back_val ();
41-
42- matcher::BoundOperationNode *CurrentNode =
43- Bound.addNode (CurrentOp, true , true );
44- if (RemainingHops == 0 ) {
45- continue ;
46- }
36+ bool matches (Operation *op, SetVector<Operation *> &backwardSlice,
37+ QueryOptions &options, unsigned tempHops) {
4738
48- for (auto Operand : CurrentOp->getOperands ()) {
49- if (auto DefiningOp = Operand.getDefiningOp ()) {
50- Bound.addEdge (CurrentOp, DefiningOp);
51- if (!Ccache.contains (Operand)) {
52- Ccache.insert (Operand);
53- TempStorage.emplace_back (DefiningOp, RemainingHops - 1 );
54- }
55- } else if (auto BlockArg = Operand.dyn_cast <BlockArgument>()) {
56- auto *Block = BlockArg.getOwner ();
39+ bool validSlice = true ;
40+ if (op->hasTrait <OpTrait::IsIsolatedFromAbove>()) {
41+ return false ;
42+ }
5743
58- if (Block->isEntryBlock () &&
59- isa<FunctionOpInterface>(Block->getParentOp ())) {
60- continue ;
44+ auto processValue = [&](Value value) {
45+ if (tempHops == 0 ) {
46+ return ;
47+ }
48+ if (auto *definingOp = value.getDefiningOp ()) {
49+ if (backwardSlice.count (definingOp) == 0 )
50+ matches (definingOp, backwardSlice, options, tempHops - 1 );
51+ } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
52+ if (options.omitBlockArguments )
53+ return ;
54+ Block *block = blockArg.getOwner ();
55+
56+ Operation *parentOp = block->getParentOp ();
57+
58+ if (parentOp && backwardSlice.count (parentOp) == 0 ) {
59+ if (parentOp->getNumRegions () == 1 &&
60+ parentOp->getRegion (0 ).getBlocks ().size () == 1 ) {
61+ validSlice = false ;
62+ return ;
63+ };
64+ matches (parentOp, backwardSlice, options, tempHops - 1 );
65+ }
66+ } else {
67+ validSlice = false ;
68+ return ;
69+ }
70+ };
71+
72+ if (!options.omitUsesFromAbove ) {
73+ llvm::for_each (op->getRegions (), [&](Region ®ion) {
74+ SmallPtrSet<Region *, 4 > descendents;
75+ region.walk (
76+ [&](Region *childRegion) { descendents.insert (childRegion); });
77+ region.walk ([&](Operation *op) {
78+ for (OpOperand &operand : op->getOpOperands ()) {
79+ if (!descendents.contains (operand.get ().getParentRegion ()))
80+ processValue (operand.get ());
81+ if (!validSlice)
82+ return ;
6183 }
84+ });
85+ });
86+ }
6287
63- Operation *ParentOp = BlockArg.getOwner ()->getParentOp ();
64- if (ParentOp) {
65- Bound.addEdge (CurrentOp, ParentOp);
66- if (!!Ccache.contains (BlockArg)) {
67- Ccache.insert (BlockArg);
68- TempStorage.emplace_back (ParentOp, RemainingHops - 1 );
69- }
70- }
71- }
88+ llvm::for_each (op->getOperands (), [&](Value operand) {
89+ processValue (operand);
90+ if (!validSlice)
91+ return ;
92+ });
93+ backwardSlice.insert (op);
94+ if (!validSlice) {
95+ return false ;
96+ }
97+ return true ;
98+ }
99+
100+ public:
101+ bool match (Operation *op, SetVector<Operation *> &backwardSlice,
102+ QueryOptions &options) {
103+ if (innerMatcher.match (op) && matches (op, backwardSlice, options, hops)) {
104+ if (!options.inclusive ) {
105+ backwardSlice.remove (op);
72106 }
107+ return true ;
73108 }
74- // We need at least 1 defining op
75- return Ccache.size () >= 2 ;
109+ return false ;
76110 }
77111
112+ private:
113+ matcher::DynMatcher innerMatcher;
114+ unsigned hops;
115+ };
116+
117+ class ForwardSliceMatcher {
78118public:
79- bool match (Operation *op, matcher::BoundOperationsGraphBuilder &Bound) {
80- if (InnerMatcher.match (op) && matches (op, Bound, Hops)) {
119+ ForwardSliceMatcher (matcher::DynMatcher &&innerMatcher, unsigned hops)
120+ : innerMatcher(std::move(innerMatcher)), hops(hops) {}
121+
122+ private:
123+ bool matches (Operation *op, SetVector<Operation *> &forwardSlice,
124+ QueryOptions &options, unsigned tempHops) {
125+
126+ if (tempHops == 0 ) {
127+ forwardSlice.insert (op);
128+ return true ;
129+ }
130+
131+ for (Region ®ion : op->getRegions ())
132+ for (Block &block : region)
133+ for (Operation &blockOp : block)
134+ if (forwardSlice.count (&blockOp) == 0 )
135+ matches (&blockOp, forwardSlice, options, tempHops - 1 );
136+ for (Value result : op->getResults ()) {
137+ for (Operation *userOp : result.getUsers ())
138+ if (forwardSlice.count (userOp) == 0 )
139+ matches (userOp, forwardSlice, options, tempHops - 1 );
140+ }
141+
142+ forwardSlice.insert (op);
143+ return true ;
144+ }
145+
146+ public:
147+ bool match (Operation *op, SetVector<Operation *> &forwardSlice,
148+ QueryOptions &options) {
149+ if (innerMatcher.match (op) && matches (op, forwardSlice, options, hops)) {
150+ if (!options.inclusive ) {
151+ forwardSlice.remove (op);
152+ }
153+ SmallVector<Operation *, 0 > v (forwardSlice.takeVector ());
154+ forwardSlice.insert (v.rbegin (), v.rend ());
81155 return true ;
82156 }
83157 return false ;
84158 }
85159
86160private:
87- matcher::DynMatcher InnerMatcher ;
88- unsigned Hops ;
161+ matcher::DynMatcher innerMatcher ;
162+ unsigned hops ;
89163};
164+
90165} // namespace detail
91166
92- inline detail::DefinitionsMatcher
93- definedBy (mlir::query::matcher::DynMatcher InnerMatcher) {
94- return detail::DefinitionsMatcher (std::move (InnerMatcher), 1 );
167+ inline detail::BackwardSliceMatcher
168+ definedBy (mlir::query::matcher::DynMatcher innerMatcher) {
169+ return detail::BackwardSliceMatcher (std::move (innerMatcher), 1 );
170+ }
171+
172+ inline detail::BackwardSliceMatcher
173+ getDefinitions (mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
174+ return detail::BackwardSliceMatcher (std::move (innerMatcher), hops);
175+ }
176+
177+ inline detail::ForwardSliceMatcher
178+ usedBy (mlir::query::matcher::DynMatcher innerMatcher) {
179+ return detail::ForwardSliceMatcher (std::move (innerMatcher), 1 );
95180}
96181
97- inline detail::DefinitionsMatcher
98- getDefinitions (mlir::query::matcher::DynMatcher InnerMatcher, unsigned Hops) {
99- assert (Hops > 0 && " hops must be >= 1" );
100- return detail::DefinitionsMatcher (std::move (InnerMatcher), Hops);
182+ inline detail::ForwardSliceMatcher
183+ getUses (mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
184+ return detail::ForwardSliceMatcher (std::move (innerMatcher), hops);
101185}
102186
103187} // namespace extramatcher
0 commit comments