1212
1313#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
1414#define MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
15+
1516#include " mlir/Analysis/SliceAnalysis.h"
16- #include " mlir/Query/Matcher/MatchersInternal.h"
1717
18- // / A matcher encapsulating the initial `getBackwardSlice` method from
19- // / SliceAnalysis.h
18+ // / A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
2019// / Additionally, it limits the slice computation to a certain depth level using
21- // / a custom filter
20+ // / a custom filter.
2221// /
23- // / Example starting from node 9, assuming the matcher
24- // / computes the slice for the first two depth levels
22+ // / Example: starting from node 9, assuming the matcher
23+ // / computes the slice for the first two depth levels:
2524// / ============================
2625// / 1 2 3 4
2726// / |_______| |______|
3736// / Assuming all local orders match the numbering order:
3837// / {5, 7, 6, 8, 9}
3938namespace mlir ::query::matcher {
39+
40+ template <typename Matcher>
4041class BackwardSliceMatcher {
4142public:
42- explicit BackwardSliceMatcher (query::matcher::DynMatcher &&innerMatcher,
43- int64_t maxDepth, bool inclusive,
44- bool omitBlockArguments, bool omitUsesFromAbove)
43+ BackwardSliceMatcher (Matcher innerMatcher, int64_t maxDepth, bool inclusive,
44+ bool omitBlockArguments, bool omitUsesFromAbove)
4545 : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth),
4646 inclusive (inclusive), omitBlockArguments(omitBlockArguments),
4747 omitUsesFromAbove(omitUsesFromAbove) {}
48- bool match (Operation *op, SetVector<Operation *> &backwardSlice) {
48+
49+ bool match (Operation *rootOp, SetVector<Operation *> &backwardSlice) {
4950 BackwardSliceOptions options;
50- return (innerMatcher.match (op) &&
51- matches (op, backwardSlice, options, maxDepth));
51+ options.inclusive = inclusive;
52+ options.omitUsesFromAbove = omitUsesFromAbove;
53+ options.omitBlockArguments = omitBlockArguments;
54+ return (innerMatcher.match (rootOp) &&
55+ matches (rootOp, backwardSlice, options, maxDepth));
5256 }
5357
5458private:
@@ -57,29 +61,75 @@ class BackwardSliceMatcher {
5761
5862private:
5963 // The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
60- // to determine whether we want to traverse the DAG or not. For example, we
61- // want to explore the DAG only if the top-level operation name is
62- // "arith.addf".
63- query::matcher::DynMatcher innerMatcher;
64- // maxDepth specifies the maximum depth that the matcher can traverse in the
65- // DAG . For example, if maxDepth is 2, the matcher will explore the defining
64+ // to determine whether we want to traverse the IR or not. For example, we
65+ // want to explore the IR only if the top-level operation name is
66+ // ` "arith.addf"` .
67+ Matcher innerMatcher;
68+ // ` maxDepth` specifies the maximum depth that the matcher can traverse the
69+ // IR . For example, if ` maxDepth` is 2, the matcher will explore the defining
6670 // operations of the top-level op up to 2 levels.
6771 int64_t maxDepth;
68-
6972 bool inclusive;
7073 bool omitBlockArguments;
7174 bool omitUsesFromAbove;
7275};
7376
74- // Matches transitive defs of a top level operation up to N levels
75- inline BackwardSliceMatcher
76- m_GetDefinitions (query::matcher::DynMatcher innerMatcher, int64_t maxDepth,
77- bool inclusive, bool omitBlockArguments,
78- bool omitUsesFromAbove) {
77+ template <typename Matcher>
78+ bool BackwardSliceMatcher<Matcher>::matches(
79+ Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
80+ BackwardSliceOptions &options, int64_t maxDepth) {
81+ backwardSlice.clear ();
82+ llvm::DenseMap<Operation *, int64_t > opDepths;
83+ // The starting point is the root op; therefore, we set its depth to 0.
84+ opDepths[rootOp] = 0 ;
85+ options.filter = [&](Operation *subOp) {
86+ // If the subOp's depth exceeds maxDepth, we stop further slicing for this
87+ // branch.
88+ if (opDepths[subOp] > maxDepth)
89+ return false ;
90+ // Examine subOp's operands to compute depths of their defining operations.
91+ for (auto operand : subOp->getOperands ()) {
92+ if (auto definingOp = operand.getDefiningOp ()) {
93+ // Set the defining operation's depth to one level greater than
94+ // subOp's depth.
95+ int64_t newDepth = opDepths[subOp] + 1 ;
96+ if (!opDepths.contains (definingOp)) {
97+ opDepths[definingOp] = newDepth;
98+ } else {
99+ opDepths[definingOp] = std::min (opDepths[definingOp], newDepth);
100+ }
101+ return !(opDepths[subOp] > maxDepth);
102+ } else {
103+ auto blockArgument = cast<BlockArgument>(operand);
104+ Operation *parentOp = blockArgument.getOwner ()->getParentOp ();
105+ if (!parentOp)
106+ continue ;
107+ int64_t newDepth = opDepths[subOp] + 1 ;
108+ if (!opDepths.contains (parentOp)) {
109+ opDepths[parentOp] = newDepth;
110+ } else {
111+ opDepths[parentOp] = std::min (opDepths[parentOp], newDepth);
112+ }
113+ return !(opDepths[parentOp] > maxDepth);
114+ }
115+ }
116+ return true ;
117+ };
118+ getBackwardSlice (rootOp, &backwardSlice, options);
119+ return true ;
120+ }
121+
122+ // Matches transitive defs of a top-level operation up to N levels.
123+ template <typename Matcher>
124+ inline BackwardSliceMatcher<Matcher>
125+ m_GetDefinitions (Matcher innerMatcher, int64_t maxDepth, bool inclusive,
126+ bool omitBlockArguments, bool omitUsesFromAbove) {
79127 assert (maxDepth >= 0 && " maxDepth must be non-negative" );
80- return BackwardSliceMatcher (std::move (innerMatcher), maxDepth, inclusive,
81- omitBlockArguments, omitUsesFromAbove);
128+ return BackwardSliceMatcher<Matcher>(std::move (innerMatcher), maxDepth,
129+ inclusive, omitBlockArguments,
130+ omitUsesFromAbove);
82131}
132+
83133} // namespace mlir::query::matcher
84134
85135#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
0 commit comments