66//
77// ===----------------------------------------------------------------------===//
88//
9- // This file provides matchers that depend on Query.
9+ // This file provides matchers for MLIRQuery with more involved pattern-matching
10+ // logic.
1011//
1112// ===----------------------------------------------------------------------===//
1213
@@ -80,46 +81,43 @@ bool BackwardSliceMatcher<Matcher>::matches(
8081 BackwardSliceOptions &options, int64_t maxDepth) {
8182 backwardSlice.clear ();
8283 llvm::DenseMap<Operation *, int64_t > opDepths;
83- // The starting point is the root op; therefore, we set its depth to 0.
84+ // Initializing the root op with a depth of 0
8485 opDepths[rootOp] = 0 ;
8586 options.filter = [&](Operation *subOp) {
86- // If the subOp's depth exceeds maxDepth, we stop further slicing for this
87- // branch .
88- if (opDepths[ subOp] > maxDepth )
87+ // If the subOp hasn't been recorded in opDepths, it is deeper than
88+ // maxDepth .
89+ if (! opDepths. contains ( subOp) )
8990 return false ;
9091 // Examine subOp's operands to compute depths of their defining operations.
9192 for (auto operand : subOp->getOperands ()) {
93+ int64_t newDepth = opDepths[subOp] + 1 ;
94+ // If the newDepth is greater than maxDepth, further computation can be
95+ // skipped.
96+ if (newDepth > maxDepth)
97+ continue ;
98+
9299 if (auto definingOp = operand.getDefiningOp ()) {
93- // Set the defining operation's depth to one level greater than
94- // subOp's depth.
95- int64_t newDepth = opDepths[subOp] + 1 ;
96- if (!opDepths.contains (definingOp)) {
100+ // Registers the minimum depth
101+ if (!opDepths.contains (definingOp) || newDepth < opDepths[definingOp])
97102 opDepths[definingOp] = newDepth;
98- } else {
99- opDepths[definingOp] = std::min (opDepths[definingOp], newDepth);
100- }
101- return !(opDepths[subOp] > maxDepth);
102103 } else {
103104 auto blockArgument = cast<BlockArgument>(operand);
104105 Operation *parentOp = blockArgument.getOwner ()->getParentOp ();
105106 if (!parentOp)
106107 continue ;
107- int64_t newDepth = opDepths[subOp] + 1 ;
108- if (!opDepths.contains (parentOp)) {
108+
109+ if (!opDepths.contains (parentOp) || newDepth < opDepths[parentOp])
109110 opDepths[parentOp] = newDepth;
110- } else {
111- opDepths[parentOp] = std::min (opDepths[parentOp], newDepth);
112- }
113- return !(opDepths[parentOp] > maxDepth);
114111 }
115112 }
116113 return true ;
117114 };
118115 getBackwardSlice (rootOp, &backwardSlice, options);
119- return true ;
116+ return options.inclusive ? backwardSlice.size () > 1
117+ : backwardSlice.size () >= 1 ;
120118}
121119
122- // Matches transitive defs of a top-level operation up to N levels.
120+ // / Matches transitive defs of a top-level operation up to N levels.
123121template <typename Matcher>
124122inline BackwardSliceMatcher<Matcher>
125123m_GetDefinitions (Matcher innerMatcher, int64_t maxDepth, bool inclusive,
@@ -130,6 +128,15 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
130128 omitUsesFromAbove);
131129}
132130
131+ // / Matches all transitive defs of a top-level operation up to N levels
132+ template <typename Matcher>
133+ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions (Matcher innerMatcher,
134+ int64_t maxDepth) {
135+ assert (maxDepth >= 0 && " maxDepth must be non-negative" );
136+ return BackwardSliceMatcher<Matcher>(std::move (innerMatcher), maxDepth, true ,
137+ false , false );
138+ }
139+
133140} // namespace mlir::query::matcher
134141
135142#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
0 commit comments