Skip to content

Commit 1dc0cf1

Browse files
committed
Improve MLIR-Query by adding matcher combinators
Limit backward-slice with nested matching Add variadic operators
1 parent 84c1564 commit 1dc0cf1

File tree

9 files changed

+377
-12
lines changed

9 files changed

+377
-12
lines changed

mlir/include/mlir/Query/Matcher/Marshallers.h

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ class MatcherDescriptor {
108108
const llvm::ArrayRef<ParserValue> args,
109109
Diagnostics *error) const = 0;
110110

111+
// If the matcher is variadic, it can take any number of arguments.
112+
virtual bool isVariadic() const = 0;
113+
111114
// Returns the number of arguments accepted by the matcher.
112115
virtual unsigned getNumArgs() const = 0;
113116

@@ -140,6 +143,8 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
140143
return marshaller(matcherFunc, matcherName, nameRange, args, error);
141144
}
142145

146+
bool isVariadic() const override { return false; }
147+
143148
unsigned getNumArgs() const override { return argKinds.size(); }
144149

145150
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
@@ -153,6 +158,54 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
153158
const std::vector<ArgKind> argKinds;
154159
};
155160

161+
class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
162+
public:
163+
using VarOp = DynMatcher::VariadicOperator;
164+
VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
165+
VarOp varOp, StringRef matcherName)
166+
: minCount(minCount), maxCount(maxCount), varOp(varOp),
167+
matcherName(matcherName) {}
168+
169+
VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
170+
Diagnostics *error) const override {
171+
if (args.size() < minCount || maxCount < args.size()) {
172+
addError(error, nameRange, ErrorType::RegistryWrongArgCount,
173+
{llvm::Twine("requires between "), llvm::Twine(minCount),
174+
llvm::Twine(" and "), llvm::Twine(maxCount),
175+
llvm::Twine(" args, got "), llvm::Twine(args.size())});
176+
return VariantMatcher();
177+
}
178+
179+
std::vector<VariantMatcher> innerArgs;
180+
for (size_t i = 0, e = args.size(); i != e; ++i) {
181+
const ParserValue &arg = args[i];
182+
const VariantValue &value = arg.value;
183+
if (!value.isMatcher()) {
184+
addError(error, arg.range, ErrorType::RegistryWrongArgType,
185+
{llvm::Twine(i + 1), llvm::Twine("Matcher: "),
186+
llvm::Twine(value.getTypeAsString())});
187+
return VariantMatcher();
188+
}
189+
innerArgs.push_back(value.getMatcher());
190+
}
191+
return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
192+
}
193+
194+
bool isVariadic() const override { return true; }
195+
196+
unsigned getNumArgs() const override { return 0; }
197+
198+
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
199+
kinds.push_back(ArgKind(ArgKind::Matcher));
200+
}
201+
202+
private:
203+
const unsigned minCount;
204+
const unsigned maxCount;
205+
const VarOp varOp;
206+
const StringRef matcherName;
207+
};
208+
156209
// Helper function to check if argument count matches expected count
157210
inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
158211
llvm::ArrayRef<ParserValue> args,
@@ -224,6 +277,14 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
224277
reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
225278
}
226279

280+
// Variadic operator overload.
281+
template <unsigned MinCount, unsigned MaxCount>
282+
std::unique_ptr<MatcherDescriptor>
283+
makeMatcherAutoMarshall(VariadicOperatorMatcherFunc<MinCount, MaxCount> func,
284+
StringRef matcherName) {
285+
return std::make_unique<VariadicOperatorMatcherDescriptor>(
286+
MinCount, MaxCount, func.varOp, matcherName);
287+
}
227288
} // namespace mlir::query::matcher::internal
228289

229290
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H

mlir/include/mlir/Query/Matcher/MatchersInternal.h

Lines changed: 103 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
//
99
// Implements the base layer of the matcher framework.
1010
//
11-
// Matchers are methods that return a Matcher which provides a method one of the
12-
// following methods: match(Operation *op), match(Operation *op,
13-
// SetVector<Operation *> &matchedOps)
11+
// Matchers are methods that return a Matcher which provide a method
12+
// `match(...)` method. The method's parameters define the context of the match.
13+
// Support includes simple (unary) matchers as well as matcher combinators.
14+
// (anyOf, allOf, etc.)
1415
//
15-
// The matcher functions are defined in include/mlir/IR/Matchers.h.
1616
// This file contains the wrapper classes needed to construct matchers for
1717
// mlir-query.
1818
//
@@ -25,6 +25,13 @@
2525
#include "llvm/ADT/IntrusiveRefCntPtr.h"
2626

2727
namespace mlir::query::matcher {
28+
class DynMatcher;
29+
namespace internal {
30+
31+
bool allOfVariadicOperator(Operation *op, ArrayRef<DynMatcher> innerMatchers);
32+
bool anyOfVariadicOperator(Operation *op, ArrayRef<DynMatcher> innerMatchers);
33+
34+
} // namespace internal
2835

2936
// Defaults to false if T has no match() method with the signature:
3037
// match(Operation* op).
@@ -84,6 +91,26 @@ class MatcherFnImpl : public MatcherInterface {
8491
MatcherFn matcherFn;
8592
};
8693

94+
// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
95+
// match the given operation.
96+
using VariadicOperatorFunction = bool (*)(Operation *op,
97+
ArrayRef<DynMatcher> innerMatchers);
98+
99+
template <VariadicOperatorFunction Func>
100+
class VariadicMatcher : public MatcherInterface {
101+
public:
102+
VariadicMatcher(std::vector<DynMatcher> matchers) : matchers(matchers) {}
103+
104+
bool match(Operation *op) override { return Func(op, matchers); }
105+
// Fallback case
106+
bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
107+
return false;
108+
}
109+
110+
private:
111+
std::vector<DynMatcher> matchers;
112+
};
113+
87114
// Matcher wraps a MatcherInterface implementation and provides match()
88115
// methods that redirect calls to the underlying implementation.
89116
class DynMatcher {
@@ -92,6 +119,31 @@ class DynMatcher {
92119
DynMatcher(MatcherInterface *implementation)
93120
: implementation(implementation) {}
94121

122+
// Construct from a variadic function.
123+
enum VariadicOperator {
124+
// Matches operations for which all provided matchers match.
125+
AllOf,
126+
// Matches operations for which at least one of the provided matchers
127+
// matches.
128+
AnyOf
129+
};
130+
131+
static std::unique_ptr<DynMatcher>
132+
constructVariadic(VariadicOperator Op,
133+
std::vector<DynMatcher> innerMatchers) {
134+
switch (Op) {
135+
case AllOf:
136+
return std::make_unique<DynMatcher>(
137+
new VariadicMatcher<internal::allOfVariadicOperator>(
138+
std::move(innerMatchers)));
139+
case AnyOf:
140+
return std::make_unique<DynMatcher>(
141+
new VariadicMatcher<internal::anyOfVariadicOperator>(
142+
std::move(innerMatchers)));
143+
}
144+
llvm_unreachable("Invalid Op value.");
145+
}
146+
95147
template <typename MatcherFn>
96148
static std::unique_ptr<DynMatcher>
97149
constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
@@ -113,6 +165,53 @@ class DynMatcher {
113165
std::string functionName;
114166
};
115167

168+
// VariadicOperatorMatcher related types.
169+
template <typename... Ps>
170+
class VariadicOperatorMatcher {
171+
public:
172+
VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
173+
: varOp(varOp), params(std::forward<Ps>(params)...) {}
174+
175+
operator std::unique_ptr<DynMatcher>() const & {
176+
return DynMatcher::constructVariadic(
177+
varOp, getMatchers(std::index_sequence_for<Ps...>()));
178+
}
179+
180+
operator std::unique_ptr<DynMatcher>() && {
181+
return DynMatcher::constructVariadic(
182+
varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
183+
}
184+
185+
private:
186+
// Helper method to unpack the tuple into a vector.
187+
template <std::size_t... Is>
188+
std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
189+
return {DynMatcher(std::get<Is>(params))...};
190+
}
191+
192+
template <std::size_t... Is>
193+
std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
194+
return {DynMatcher(std::get<Is>(std::move(params)))...};
195+
}
196+
197+
const DynMatcher::VariadicOperator varOp;
198+
std::tuple<Ps...> params;
199+
};
200+
201+
// Overloaded function object to generate VariadicOperatorMatcher objects from
202+
// arbitrary matchers.
203+
template <unsigned MinCount, unsigned MaxCount>
204+
struct VariadicOperatorMatcherFunc {
205+
DynMatcher::VariadicOperator varOp;
206+
207+
template <typename... Ms>
208+
VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
209+
static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
210+
"invalid number of parameters for variadic matcher");
211+
return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
212+
}
213+
};
214+
116215
} // namespace mlir::query::matcher
117216

118217
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H

mlir/include/mlir/Query/Matcher/SliceMatchers.h

Lines changed: 103 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file provides matchers for MLIRQuery that peform slicing analysis
9+
// This file defines slicing-analysis matchers that extend and abstract the
10+
// core implementations from `SliceAnalysis.h`.
1011
//
1112
//===----------------------------------------------------------------------===//
1213

@@ -15,9 +16,9 @@
1516

1617
#include "mlir/Analysis/SliceAnalysis.h"
1718

18-
/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
19-
/// Additionally, it limits the slice computation to a certain depth level using
20-
/// a custom filter.
19+
/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
20+
/// if `innerMatcher` matches. The traversal stops once the desired depth level
21+
/// is reached.
2122
///
2223
/// Example: starting from node 9, assuming the matcher
2324
/// computes the slice for the first two depth levels:
@@ -116,6 +117,81 @@ bool BackwardSliceMatcher<Matcher>::matches(
116117
: backwardSlice.size() >= 1;
117118
}
118119

120+
/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
121+
/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
122+
template <typename BaseMatcher, typename Filter>
123+
class PredicateBackwardSliceMatcher {
124+
public:
125+
PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
126+
bool inclusive, bool omitBlockArguments,
127+
bool omitUsesFromAbove)
128+
: innerMatcher(std::move(innerMatcher)),
129+
filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
130+
omitBlockArguments(omitBlockArguments),
131+
omitUsesFromAbove(omitUsesFromAbove) {}
132+
133+
bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
134+
backwardSlice.clear();
135+
BackwardSliceOptions options;
136+
options.inclusive = inclusive;
137+
options.omitUsesFromAbove = omitUsesFromAbove;
138+
options.omitBlockArguments = omitBlockArguments;
139+
if (innerMatcher.match(rootOp)) {
140+
options.filter = [&](Operation *subOp) {
141+
return !filterMatcher.match(subOp);
142+
};
143+
getBackwardSlice(rootOp, &backwardSlice, options);
144+
return options.inclusive ? backwardSlice.size() > 1
145+
: backwardSlice.size() >= 1;
146+
}
147+
return false;
148+
}
149+
150+
private:
151+
BaseMatcher innerMatcher;
152+
Filter filterMatcher;
153+
bool inclusive;
154+
bool omitBlockArguments;
155+
bool omitUsesFromAbove;
156+
};
157+
158+
/// Computes the forward-slice of all users reachable from `rootOp`,
159+
/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
160+
template <typename BaseMatcher, typename Filter>
161+
class PredicateForwardSliceMatcher {
162+
public:
163+
PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
164+
bool inclusive)
165+
: innerMatcher(std::move(innerMatcher)),
166+
filterMatcher(std::move(filterMatcher)), inclusive(inclusive) {}
167+
168+
bool match(Operation *rootOp, SetVector<Operation *> &forward) {
169+
forward.clear();
170+
ForwardSliceOptions options;
171+
options.inclusive = inclusive;
172+
if (innerMatcher.match(rootOp)) {
173+
options.filter = [&](Operation *subOp) {
174+
return !filterMatcher.match(subOp);
175+
};
176+
getForwardSlice(rootOp, &forward, options);
177+
return options.inclusive ? forward.size() > 1 : forward.size() >= 1;
178+
}
179+
return false;
180+
}
181+
182+
private:
183+
BaseMatcher innerMatcher;
184+
Filter filterMatcher;
185+
bool inclusive;
186+
};
187+
188+
const matcher::VariadicOperatorMatcherFunc<1,
189+
std::numeric_limits<unsigned>::max()>
190+
anyOf = {matcher::DynMatcher::AnyOf};
191+
const matcher::VariadicOperatorMatcherFunc<1,
192+
std::numeric_limits<unsigned>::max()>
193+
allOf = {matcher::DynMatcher::AllOf};
194+
119195
/// Matches transitive defs of a top-level operation up to N levels.
120196
template <typename Matcher>
121197
inline BackwardSliceMatcher<Matcher>
@@ -127,7 +203,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
127203
omitUsesFromAbove);
128204
}
129205

130-
/// Matches all transitive defs of a top-level operation up to N levels
206+
/// Matches all transitive defs of a top-level operation up to N levels.
131207
template <typename Matcher>
132208
inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
133209
int64_t maxDepth) {
@@ -136,6 +212,28 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
136212
false, false);
137213
}
138214

215+
/// Matches all transitive defs of a top-level operation and stops where
216+
/// `filterMatcher` rejects.
217+
template <typename BaseMatcher, typename Filter>
218+
inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
219+
m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
220+
bool inclusive, bool omitBlockArguments,
221+
bool omitUsesFromAbove) {
222+
return PredicateBackwardSliceMatcher<BaseMatcher, Filter>(
223+
std::move(innerMatcher), std::move(filterMatcher), inclusive,
224+
omitBlockArguments, omitUsesFromAbove);
225+
}
226+
227+
/// Matches all users of a top-level operation and stops where
228+
/// `filterMatcher` rejects.
229+
template <typename BaseMatcher, typename Filter>
230+
inline PredicateForwardSliceMatcher<BaseMatcher, Filter>
231+
m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
232+
bool inclusive) {
233+
return PredicateForwardSliceMatcher<BaseMatcher, Filter>(
234+
std::move(innerMatcher), std::move(filterMatcher), inclusive);
235+
}
236+
139237
} // namespace mlir::query::matcher
140238

141239
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H

0 commit comments

Comments
 (0)