6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
//
9
- // This file provides matchers for MLIRQuery that peform slicing analysis
9
+ // This file defines slicing-analysis matchers that extend and abstract the
10
+ // core implementations from `SliceAnalysis.h`.
10
11
//
11
12
// ===----------------------------------------------------------------------===//
12
13
15
16
16
17
#include " mlir/Analysis/SliceAnalysis.h"
17
18
18
- // / A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
19
- // / Additionally, it limits the slice computation to a certain depth level using
20
- // / a custom filter .
19
+ // / Computes the backward-slice of all transitive defs reachable from `rootOp`,
20
+ // / if `innerMatcher` matches. The traversal stops once the desired depth level
21
+ // / is reached .
21
22
// /
22
23
// / Example: starting from node 9, assuming the matcher
23
24
// / computes the slice for the first two depth levels:
@@ -116,6 +117,81 @@ bool BackwardSliceMatcher<Matcher>::matches(
116
117
: backwardSlice.size () >= 1 ;
117
118
}
118
119
120
+ // / Computes the backward-slice of all transitive defs reachable from `rootOp`,
121
+ // / if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
122
+ template <typename BaseMatcher, typename Filter>
123
+ class PredicateBackwardSliceMatcher {
124
+ public:
125
+ PredicateBackwardSliceMatcher (BaseMatcher innerMatcher, Filter filterMatcher,
126
+ bool inclusive, bool omitBlockArguments,
127
+ bool omitUsesFromAbove)
128
+ : innerMatcher(std::move(innerMatcher)),
129
+ filterMatcher (std::move(filterMatcher)), inclusive(inclusive),
130
+ omitBlockArguments(omitBlockArguments),
131
+ omitUsesFromAbove(omitUsesFromAbove) {}
132
+
133
+ bool match (Operation *rootOp, SetVector<Operation *> &backwardSlice) {
134
+ backwardSlice.clear ();
135
+ BackwardSliceOptions options;
136
+ options.inclusive = inclusive;
137
+ options.omitUsesFromAbove = omitUsesFromAbove;
138
+ options.omitBlockArguments = omitBlockArguments;
139
+ if (innerMatcher.match (rootOp)) {
140
+ options.filter = [&](Operation *subOp) {
141
+ return !filterMatcher.match (subOp);
142
+ };
143
+ getBackwardSlice (rootOp, &backwardSlice, options);
144
+ return options.inclusive ? backwardSlice.size () > 1
145
+ : backwardSlice.size () >= 1 ;
146
+ }
147
+ return false ;
148
+ }
149
+
150
+ private:
151
+ BaseMatcher innerMatcher;
152
+ Filter filterMatcher;
153
+ bool inclusive;
154
+ bool omitBlockArguments;
155
+ bool omitUsesFromAbove;
156
+ };
157
+
158
+ // / Computes the forward-slice of all users reachable from `rootOp`,
159
+ // / if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
160
+ template <typename BaseMatcher, typename Filter>
161
+ class PredicateForwardSliceMatcher {
162
+ public:
163
+ PredicateForwardSliceMatcher (BaseMatcher innerMatcher, Filter filterMatcher,
164
+ bool inclusive)
165
+ : innerMatcher(std::move(innerMatcher)),
166
+ filterMatcher (std::move(filterMatcher)), inclusive(inclusive) {}
167
+
168
+ bool match (Operation *rootOp, SetVector<Operation *> &forward) {
169
+ forward.clear ();
170
+ ForwardSliceOptions options;
171
+ options.inclusive = inclusive;
172
+ if (innerMatcher.match (rootOp)) {
173
+ options.filter = [&](Operation *subOp) {
174
+ return !filterMatcher.match (subOp);
175
+ };
176
+ getForwardSlice (rootOp, &forward, options);
177
+ return options.inclusive ? forward.size () > 1 : forward.size () >= 1 ;
178
+ }
179
+ return false ;
180
+ }
181
+
182
+ private:
183
+ BaseMatcher innerMatcher;
184
+ Filter filterMatcher;
185
+ bool inclusive;
186
+ };
187
+
188
+ const matcher::VariadicOperatorMatcherFunc<1 ,
189
+ std::numeric_limits<unsigned >::max()>
190
+ anyOf = {matcher::DynMatcher::AnyOf};
191
+ const matcher::VariadicOperatorMatcherFunc<1 ,
192
+ std::numeric_limits<unsigned >::max()>
193
+ allOf = {matcher::DynMatcher::AllOf};
194
+
119
195
// / Matches transitive defs of a top-level operation up to N levels.
120
196
template <typename Matcher>
121
197
inline BackwardSliceMatcher<Matcher>
@@ -127,7 +203,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
127
203
omitUsesFromAbove);
128
204
}
129
205
130
- // / Matches all transitive defs of a top-level operation up to N levels
206
+ // / Matches all transitive defs of a top-level operation up to N levels.
131
207
template <typename Matcher>
132
208
inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions (Matcher innerMatcher,
133
209
int64_t maxDepth) {
@@ -136,6 +212,28 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
136
212
false , false );
137
213
}
138
214
215
+ // / Matches all transitive defs of a top-level operation and stops where
216
+ // / `filterMatcher` rejects.
217
+ template <typename BaseMatcher, typename Filter>
218
+ inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
219
+ m_GetDefinitionsByPredicate (BaseMatcher innerMatcher, Filter filterMatcher,
220
+ bool inclusive, bool omitBlockArguments,
221
+ bool omitUsesFromAbove) {
222
+ return PredicateBackwardSliceMatcher<BaseMatcher, Filter>(
223
+ std::move (innerMatcher), std::move (filterMatcher), inclusive,
224
+ omitBlockArguments, omitUsesFromAbove);
225
+ }
226
+
227
+ // / Matches all users of a top-level operation and stops where
228
+ // / `filterMatcher` rejects.
229
+ template <typename BaseMatcher, typename Filter>
230
+ inline PredicateForwardSliceMatcher<BaseMatcher, Filter>
231
+ m_GetUsersByPredicate (BaseMatcher innerMatcher, Filter filterMatcher,
232
+ bool inclusive) {
233
+ return PredicateForwardSliceMatcher<BaseMatcher, Filter>(
234
+ std::move (innerMatcher), std::move (filterMatcher), inclusive);
235
+ }
236
+
139
237
} // namespace mlir::query::matcher
140
238
141
239
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
0 commit comments