diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h index 012bf7b9ec4a9..5fe6965f32efb 100644 --- a/mlir/include/mlir/Query/Matcher/Marshallers.h +++ b/mlir/include/mlir/Query/Matcher/Marshallers.h @@ -108,6 +108,9 @@ class MatcherDescriptor { const llvm::ArrayRef args, Diagnostics *error) const = 0; + // If the matcher is variadic, it can take any number of arguments. + virtual bool isVariadic() const = 0; + // Returns the number of arguments accepted by the matcher. virtual unsigned getNumArgs() const = 0; @@ -140,6 +143,8 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor { return marshaller(matcherFunc, matcherName, nameRange, args, error); } + bool isVariadic() const override { return false; } + unsigned getNumArgs() const override { return argKinds.size(); } void getArgKinds(unsigned argNo, std::vector &kinds) const override { @@ -153,6 +158,54 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor { const std::vector argKinds; }; +class VariadicOperatorMatcherDescriptor : public MatcherDescriptor { +public: + using VarOp = DynMatcher::VariadicOperator; + VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount, + VarOp varOp, StringRef matcherName) + : minCount(minCount), maxCount(maxCount), varOp(varOp), + matcherName(matcherName) {} + + VariantMatcher create(SourceRange nameRange, ArrayRef args, + Diagnostics *error) const override { + if (args.size() < minCount || maxCount < args.size()) { + addError(error, nameRange, ErrorType::RegistryWrongArgCount, + {llvm::Twine("requires between "), llvm::Twine(minCount), + llvm::Twine(" and "), llvm::Twine(maxCount), + llvm::Twine(" args, got "), llvm::Twine(args.size())}); + return VariantMatcher(); + } + + std::vector innerArgs; + for (int64_t i = 0, e = args.size(); i != e; ++i) { + const ParserValue &arg = args[i]; + const VariantValue &value = arg.value; + if (!value.isMatcher()) { + addError(error, arg.range, ErrorType::RegistryWrongArgType, + {llvm::Twine(i + 1), llvm::Twine("matcher: "), + llvm::Twine(value.getTypeAsString())}); + return VariantMatcher(); + } + innerArgs.push_back(value.getMatcher()); + } + return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs)); + } + + bool isVariadic() const override { return true; } + + unsigned getNumArgs() const override { return 0; } + + void getArgKinds(unsigned argNo, std::vector &kinds) const override { + kinds.push_back(ArgKind(ArgKind::Matcher)); + } + +private: + const unsigned minCount; + const unsigned maxCount; + const VarOp varOp; + const StringRef matcherName; +}; + // Helper function to check if argument count matches expected count inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount, llvm::ArrayRef args, @@ -224,6 +277,14 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...), reinterpret_cast(matcherFunc), matcherName, argKinds); } +// Variadic operator overload. +template +std::unique_ptr +makeMatcherAutoMarshall(VariadicOperatorMatcherFunc func, + StringRef matcherName) { + return std::make_unique( + MinCount, MaxCount, func.varOp, matcherName); +} } // namespace mlir::query::matcher::internal #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h index f8abf20ef60bb..6d06ca13d1344 100644 --- a/mlir/include/mlir/Query/Matcher/MatchFinder.h +++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h @@ -21,7 +21,9 @@ namespace mlir::query::matcher { -/// A class that provides utilities to find operations in the IR. +/// Finds and collects matches from the IR. After construction +/// `collectMatches` can be used to traverse the IR and apply +/// matchers. class MatchFinder { public: diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h index 183b2514e109f..88109430b6feb 100644 --- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h +++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h @@ -8,11 +8,11 @@ // // Implements the base layer of the matcher framework. // -// Matchers are methods that return a Matcher which provides a method one of the -// following methods: match(Operation *op), match(Operation *op, -// SetVector &matchedOps) +// Matchers are methods that return a Matcher which provides a +// `match(...)` method whose parameters define the context of the match. +// Support includes simple (unary) matchers as well as matcher combinators +// (anyOf, allOf, etc.) // -// The matcher functions are defined in include/mlir/IR/Matchers.h. // This file contains the wrapper classes needed to construct matchers for // mlir-query. // @@ -25,6 +25,15 @@ #include "llvm/ADT/IntrusiveRefCntPtr.h" namespace mlir::query::matcher { +class DynMatcher; +namespace internal { + +bool allOfVariadicOperator(Operation *op, SetVector *matchedOps, + ArrayRef innerMatchers); +bool anyOfVariadicOperator(Operation *op, SetVector *matchedOps, + ArrayRef innerMatchers); + +} // namespace internal // Defaults to false if T has no match() method with the signature: // match(Operation* op). @@ -84,6 +93,27 @@ class MatcherFnImpl : public MatcherInterface { MatcherFn matcherFn; }; +// VariadicMatcher takes a vector of Matchers and returns true if any Matchers +// match the given operation. +using VariadicOperatorFunction = bool (*)(Operation *op, + SetVector *matchedOps, + ArrayRef innerMatchers); + +template +class VariadicMatcher : public MatcherInterface { +public: + VariadicMatcher(std::vector matchers) + : matchers(std::move(matchers)) {} + + bool match(Operation *op) override { return Func(op, nullptr, matchers); } + bool match(Operation *op, SetVector &matchedOps) override { + return Func(op, &matchedOps, matchers); + } + +private: + std::vector matchers; +}; + // Matcher wraps a MatcherInterface implementation and provides match() // methods that redirect calls to the underlying implementation. class DynMatcher { @@ -92,6 +122,31 @@ class DynMatcher { DynMatcher(MatcherInterface *implementation) : implementation(implementation) {} + // Construct from a variadic function. + enum VariadicOperator { + // Matches operations for which all provided matchers match. + AllOf, + // Matches operations for which at least one of the provided matchers + // matches. + AnyOf + }; + + static std::unique_ptr + constructVariadic(VariadicOperator Op, + std::vector innerMatchers) { + switch (Op) { + case AllOf: + return std::make_unique( + new VariadicMatcher( + std::move(innerMatchers))); + case AnyOf: + return std::make_unique( + new VariadicMatcher( + std::move(innerMatchers))); + } + llvm_unreachable("Invalid Op value."); + } + template static std::unique_ptr constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) { @@ -113,6 +168,59 @@ class DynMatcher { std::string functionName; }; +// VariadicOperatorMatcher related types. +template +class VariadicOperatorMatcher { +public: + VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params) + : varOp(varOp), params(std::forward(params)...) {} + + operator std::unique_ptr() const & { + return DynMatcher::constructVariadic( + varOp, getMatchers(std::index_sequence_for())); + } + + operator std::unique_ptr() && { + return DynMatcher::constructVariadic( + varOp, std::move(*this).getMatchers(std::index_sequence_for())); + } + +private: + // Helper method to unpack the tuple into a vector. + template + std::vector getMatchers(std::index_sequence) const & { + return {DynMatcher(std::get(params))...}; + } + + template + std::vector getMatchers(std::index_sequence) && { + return {DynMatcher(std::get(std::move(params)))...}; + } + + const DynMatcher::VariadicOperator varOp; + std::tuple params; +}; + +// Overloaded function object to generate VariadicOperatorMatcher objects from +// arbitrary matchers. +template +struct VariadicOperatorMatcherFunc { + DynMatcher::VariadicOperator varOp; + + template + VariadicOperatorMatcher operator()(Ms &&...Ps) const { + static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount, + "invalid number of parameters for variadic matcher"); + return VariadicOperatorMatcher(varOp, std::forward(Ps)...); + } +}; + +namespace internal { +const VariadicOperatorMatcherFunc<1, std::numeric_limits::max()> + anyOf = {DynMatcher::AnyOf}; +const VariadicOperatorMatcherFunc<1, std::numeric_limits::max()> + allOf = {DynMatcher::AllOf}; +} // namespace internal } // namespace mlir::query::matcher #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H diff --git a/mlir/include/mlir/Query/Matcher/SliceMatchers.h b/mlir/include/mlir/Query/Matcher/SliceMatchers.h index 441205b3a9615..7181648f06f89 100644 --- a/mlir/include/mlir/Query/Matcher/SliceMatchers.h +++ b/mlir/include/mlir/Query/Matcher/SliceMatchers.h @@ -6,7 +6,8 @@ // //===----------------------------------------------------------------------===// // -// This file provides matchers for MLIRQuery that peform slicing analysis +// This file defines slicing-analysis matchers that extend and abstract the +// core implementations from `SliceAnalysis.h`. // //===----------------------------------------------------------------------===// @@ -16,9 +17,9 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/IR/Operation.h" -/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h. -/// Additionally, it limits the slice computation to a certain depth level using -/// a custom filter. +/// Computes the backward-slice of all transitive defs reachable from `rootOp`, +/// if `innerMatcher` matches. The traversal stops once the desired depth level +/// is reached. /// /// Example: starting from node 9, assuming the matcher /// computes the slice for the first two depth levels: @@ -119,6 +120,77 @@ bool BackwardSliceMatcher::matches( : backwardSlice.size() >= 1; } +/// Computes the backward-slice of all transitive defs reachable from `rootOp`, +/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches. +template +class PredicateBackwardSliceMatcher { +public: + PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher, + bool inclusive, bool omitBlockArguments, + bool omitUsesFromAbove) + : innerMatcher(std::move(innerMatcher)), + filterMatcher(std::move(filterMatcher)), inclusive(inclusive), + omitBlockArguments(omitBlockArguments), + omitUsesFromAbove(omitUsesFromAbove) {} + + bool match(Operation *rootOp, SetVector &backwardSlice) { + backwardSlice.clear(); + BackwardSliceOptions options; + options.inclusive = inclusive; + options.omitUsesFromAbove = omitUsesFromAbove; + options.omitBlockArguments = omitBlockArguments; + if (innerMatcher.match(rootOp)) { + options.filter = [&](Operation *subOp) { + return !filterMatcher.match(subOp); + }; + LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options); + assert(result.succeeded() && "expected backward slice to succeed"); + (void)result; + return options.inclusive ? backwardSlice.size() > 1 + : backwardSlice.size() >= 1; + } + return false; + } + +private: + BaseMatcher innerMatcher; + Filter filterMatcher; + bool inclusive; + bool omitBlockArguments; + bool omitUsesFromAbove; +}; + +/// Computes the forward-slice of all users reachable from `rootOp`, +/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches. +template +class PredicateForwardSliceMatcher { +public: + PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher, + bool inclusive) + : innerMatcher(std::move(innerMatcher)), + filterMatcher(std::move(filterMatcher)), inclusive(inclusive) {} + + bool match(Operation *rootOp, SetVector &forwardSlice) { + forwardSlice.clear(); + ForwardSliceOptions options; + options.inclusive = inclusive; + if (innerMatcher.match(rootOp)) { + options.filter = [&](Operation *subOp) { + return !filterMatcher.match(subOp); + }; + getForwardSlice(rootOp, &forwardSlice, options); + return options.inclusive ? forwardSlice.size() > 1 + : forwardSlice.size() >= 1; + } + return false; + } + +private: + BaseMatcher innerMatcher; + Filter filterMatcher; + bool inclusive; +}; + /// Matches transitive defs of a top-level operation up to N levels. template inline BackwardSliceMatcher @@ -130,7 +202,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive, omitUsesFromAbove); } -/// Matches all transitive defs of a top-level operation up to N levels +/// Matches all transitive defs of a top-level operation up to N levels. template inline BackwardSliceMatcher m_GetAllDefinitions(Matcher innerMatcher, int64_t maxDepth) { @@ -139,6 +211,28 @@ inline BackwardSliceMatcher m_GetAllDefinitions(Matcher innerMatcher, false, false); } +/// Matches all transitive defs of a top-level operation and stops where +/// `filterMatcher` rejects. +template +inline PredicateBackwardSliceMatcher +m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher, + bool inclusive, bool omitBlockArguments, + bool omitUsesFromAbove) { + return PredicateBackwardSliceMatcher( + std::move(innerMatcher), std::move(filterMatcher), inclusive, + omitBlockArguments, omitUsesFromAbove); +} + +/// Matches all users of a top-level operation and stops where +/// `filterMatcher` rejects. +template +inline PredicateForwardSliceMatcher +m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher, + bool inclusive) { + return PredicateForwardSliceMatcher( + std::move(innerMatcher), std::move(filterMatcher), inclusive); +} + } // namespace mlir::query::matcher #endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h index 98c0a18e25101..1a47576de1841 100644 --- a/mlir/include/mlir/Query/Matcher/VariantValue.h +++ b/mlir/include/mlir/Query/Matcher/VariantValue.h @@ -26,7 +26,12 @@ enum class ArgKind { Boolean, Matcher, Signed, String }; // A variant matcher object to abstract simple and complex matchers into a // single object type. class VariantMatcher { - class MatcherOps; + class MatcherOps { + public: + std::optional + constructVariadicOperator(DynMatcher::VariadicOperator varOp, + ArrayRef innerMatchers) const; + }; // Payload interface to be specialized by each matcher type. It follows a // similar interface as VariantMatcher itself. @@ -43,6 +48,9 @@ class VariantMatcher { // Clones the provided matcher. static VariantMatcher SingleMatcher(DynMatcher matcher); + static VariantMatcher + VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, + ArrayRef args); // Makes the matcher the "null" matcher. void reset(); @@ -61,6 +69,7 @@ class VariantMatcher { : value(std::move(value)) {} class SinglePayload; + class VariadicOpPayload; std::shared_ptr value; }; diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt index 629479bf7adc1..ba202762fdfbb 100644 --- a/mlir/lib/Query/Matcher/CMakeLists.txt +++ b/mlir/lib/Query/Matcher/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library(MLIRQueryMatcher MatchFinder.cpp + MatchersInternal.cpp Parser.cpp RegistryManager.cpp VariantValue.cpp diff --git a/mlir/lib/Query/Matcher/MatchersInternal.cpp b/mlir/lib/Query/Matcher/MatchersInternal.cpp new file mode 100644 index 0000000000000..01f412ade846b --- /dev/null +++ b/mlir/lib/Query/Matcher/MatchersInternal.cpp @@ -0,0 +1,33 @@ +//===--- MatchersInternal.cpp----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Query/Matcher/MatchersInternal.h" +#include "llvm/ADT/SetVector.h" + +namespace mlir::query::matcher { + +namespace internal { + +bool allOfVariadicOperator(Operation *op, SetVector *matchedOps, + ArrayRef innerMatchers) { + return llvm::all_of(innerMatchers, [&](const DynMatcher &matcher) { + if (matchedOps) + return matcher.match(op, *matchedOps); + return matcher.match(op); + }); +} +bool anyOfVariadicOperator(Operation *op, SetVector *matchedOps, + ArrayRef innerMatchers) { + return llvm::any_of(innerMatchers, [&](const DynMatcher &matcher) { + if (matchedOps) + return matcher.match(op, *matchedOps); + return matcher.match(op); + }); +} +} // namespace internal +} // namespace mlir::query::matcher diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp index 4b511c5f009e7..08b610453b11a 100644 --- a/mlir/lib/Query/Matcher/RegistryManager.cpp +++ b/mlir/lib/Query/Matcher/RegistryManager.cpp @@ -64,7 +64,7 @@ std::vector RegistryManager::getAcceptedCompletionTypes( unsigned argNumber = ctxEntry.second; std::vector nextTypeSet; - if (argNumber < ctor->getNumArgs()) + if (ctor->isVariadic() || argNumber < ctor->getNumArgs()) ctor->getArgKinds(argNumber, nextTypeSet); typeSet.insert(nextTypeSet.begin(), nextTypeSet.end()); @@ -83,7 +83,7 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef acceptedTypes, const internal::MatcherDescriptor &matcher = *m.getValue(); llvm::StringRef name = m.getKey(); - unsigned numArgs = matcher.getNumArgs(); + unsigned numArgs = matcher.isVariadic() ? 1 : matcher.getNumArgs(); std::vector> argKinds(numArgs); for (const ArgKind &kind : acceptedTypes) { @@ -115,6 +115,9 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef acceptedTypes, } } + if (matcher.isVariadic()) + os << ",..."; + os << ")"; typedText += "("; diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp index 1cb2d48f9d56f..7bf4774dba830 100644 --- a/mlir/lib/Query/Matcher/VariantValue.cpp +++ b/mlir/lib/Query/Matcher/VariantValue.cpp @@ -27,12 +27,64 @@ class VariantMatcher::SinglePayload : public VariantMatcher::Payload { DynMatcher matcher; }; +class VariantMatcher::VariadicOpPayload : public VariantMatcher::Payload { +public: + VariadicOpPayload(DynMatcher::VariadicOperator varOp, + std::vector args) + : varOp(varOp), args(std::move(args)) {} + + std::optional getDynMatcher() const override { + std::vector dynMatchers; + for (auto variantMatcher : args) { + std::optional dynMatcher = variantMatcher.getDynMatcher(); + if (dynMatcher) + dynMatchers.push_back(dynMatcher.value()); + } + auto result = DynMatcher::constructVariadic(varOp, dynMatchers); + return *result; + } + + std::string getTypeAsString() const override { + std::string inner; + llvm::interleave( + args, [&](auto const &arg) { inner += arg.getTypeAsString(); }, + [&] { inner += " & "; }); + return inner; + } + +private: + const DynMatcher::VariadicOperator varOp; + const std::vector args; +}; + VariantMatcher::VariantMatcher() = default; VariantMatcher VariantMatcher::SingleMatcher(DynMatcher matcher) { return VariantMatcher(std::make_shared(std::move(matcher))); } +VariantMatcher +VariantMatcher::VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, + ArrayRef args) { + return VariantMatcher( + std::make_shared(varOp, std::move(args))); +} + +std::optional VariantMatcher::MatcherOps::constructVariadicOperator( + DynMatcher::VariadicOperator varOp, + ArrayRef innerMatchers) const { + std::vector dynMatchers; + for (const auto &innerMatcher : innerMatchers) { + if (!innerMatcher.value) + return std::nullopt; + std::optional inner = innerMatcher.value->getDynMatcher(); + if (!inner) + return std::nullopt; + dynMatchers.push_back(*inner); + } + return *DynMatcher::constructVariadic(varOp, dynMatchers); +} + std::optional VariantMatcher::getDynMatcher() const { return value ? value->getDynMatcher() : std::nullopt; } diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp index 803284d6df86a..637e1f3cdef87 100644 --- a/mlir/lib/Query/Query.cpp +++ b/mlir/lib/Query/Query.cpp @@ -10,6 +10,7 @@ #include "QueryParser.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/Verifier.h" #include "mlir/Query/Matcher/MatchFinder.h" #include "mlir/Query/QuerySession.h" #include "llvm/ADT/SetVector.h" @@ -68,6 +69,8 @@ static Operation *extractFunction(std::vector &ops, // Clone operations and build function body std::vector clonedOps; std::vector clonedVals; + // TODO: Handle extraction of operations with compute payloads defined via + // regions. for (Operation *slicedOp : slice) { Operation *clonedOp = clonedOps.emplace_back(builder.clone(*slicedOp, mapper)); @@ -129,6 +132,8 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const { finder.flattenMatchedOps(matches); Operation *function = extractFunction(flattenedMatches, rootOp->getContext(), functionName); + if (failed(verify(function))) + return mlir::failure(); os << "\n" << *function << "\n\n"; function->erase(); return mlir::success(); diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/backward-slice-union.mlir similarity index 71% rename from mlir/test/mlir-query/complex-test.mlir rename to mlir/test/mlir-query/backward-slice-union.mlir index ad96f03747a43..f8f88c2043749 100644 --- a/mlir/test/mlir-query/complex-test.mlir +++ b/mlir/test/mlir-query/backward-slice-union.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-query %s -c "m getAllDefinitions(hasOpName(\"arith.addf\"),2)" | FileCheck %s +// RUN: mlir-query %s -c "m anyOf(getAllDefinitions(hasOpName(\"arith.addf\"),2),getAllDefinitions(hasOpName(\"tensor.extract\"),1))" | FileCheck %s #map = affine_map<(d0, d1) -> (d0, d1)> func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) { @@ -19,14 +19,23 @@ func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) } // CHECK: Match #1: - // CHECK: %[[LINALG:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} // CHECK-SAME: ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) + +// CHECK: {{.*}}.mlir:7:10: note: "root" binds here // CHECK: %[[ADDF1:.*]] = arith.addf %in, %in : f32 // CHECK: Match #2: +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[LINALG]] {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32> +// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index +// CHECK: {{.*}}.mlir:14:18: note: "root" binds here +// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32> + +// CHECK: Match #3: // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[LINALG]] {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32> // CHECK: %[[C2:.*]] = arith.constant {{.*}} : index // CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32> + +// CHECK: {{.*}}.mlir:15:10: note: "root" binds here // CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32 diff --git a/mlir/test/mlir-query/forward-slice-by-predicate.mlir b/mlir/test/mlir-query/forward-slice-by-predicate.mlir new file mode 100644 index 0000000000000..e11378da89d9f --- /dev/null +++ b/mlir/test/mlir-query/forward-slice-by-predicate.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-query %s -c "m getUsersByPredicate(anyOf(hasOpName(\"memref.alloc\"),isConstantOp()),anyOf(hasOpName(\"affine.load\"), hasOpName(\"memref.dealloc\")),true)" | FileCheck %s + +func.func @slice_depth1_loop_nest_with_offsets() { + %0 = memref.alloc() : memref<100xf32> + %cst = arith.constant 7.000000e+00 : f32 + affine.for %i0 = 0 to 16 { + %a0 = affine.apply affine_map<(d0) -> (d0 + 2)>(%i0) + affine.store %cst, %0[%a0] : memref<100xf32> + } + affine.for %i1 = 4 to 8 { + %a1 = affine.apply affine_map<(d0) -> (d0 - 1)>(%i1) + %1 = affine.load %0[%a1] : memref<100xf32> + } + return +} + +// CHECK: Match #1: +// CHECK: {{.*}}.mlir:4:8: note: "root" binds here +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<100xf32> + +// CHECK: affine.store %cst, %0[%a0] : memref<100xf32> + +// CHECK: Match #2: +// CHECK: {{.*}}.mlir:5:10: note: "root" binds here +// CHECK: %[[CST:.*]] = arith.constant 7.000000e+00 : f32 + +// CHECK: affine.store %[[CST]], %0[%a0] : memref<100xf32> diff --git a/mlir/test/mlir-query/logical-operator-test.mlir b/mlir/test/mlir-query/logical-operator-test.mlir new file mode 100644 index 0000000000000..ac05428287abd --- /dev/null +++ b/mlir/test/mlir-query/logical-operator-test.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-query %s -c "m allOf(hasOpName(\"memref.alloca\"), hasOpAttrName(\"alignment\"))" | FileCheck %s + +func.func @dynamic_alloca(%arg0: index, %arg1: index) -> memref { + %0 = memref.alloca(%arg0, %arg1) : memref + memref.alloca(%arg0, %arg1) {alignment = 32} : memref + return %0 : memref +} + +// CHECK: Match #1: +// CHECK: {{.*}}.mlir:5:3: note: "root" binds here +// CHECK: memref.alloca(%arg0, %arg1) {alignment = 32} : memref diff --git a/mlir/test/mlir-query/slice-function-extraction.mlir b/mlir/test/mlir-query/slice-function-extraction.mlir new file mode 100644 index 0000000000000..e55d5e77c5736 --- /dev/null +++ b/mlir/test/mlir-query/slice-function-extraction.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-query %s -c "m getDefinitionsByPredicate(hasOpName(\"memref.store\"),hasOpName(\"memref.alloc\"),true,false,false).extract(\"backward_slice\")" | FileCheck %s + +// CHECK: func.func @backward_slice(%{{.*}}: memref<10xf32>) -> (f32, index, index, f32, index, index, f32) { +// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[I0:.*]] = affine.apply affine_map<()[s0] -> (s0)>()[%[[C0]]] +// CHECK-NEXT: memref.store %[[CST0]], %{{.*}}[%[[I0]]] : memref<10xf32> +// CHECK-NEXT: %[[CST2:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[I1:.*]] = affine.apply affine_map<() -> (0)>() +// CHECK-NEXT: memref.store %[[CST2]], %{{.*}}[%[[I1]]] : memref<10xf32> +// CHECK-NEXT: %[[C1:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[LOAD:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<10xf32> +// CHECK-NEXT: memref.store %[[LOAD]], %{{.*}}[%[[C1]]] : memref<10xf32> +// CHECK-NEXT: return %[[CST0]], %[[C0]], %[[I0]], %[[CST2]], %[[I1]], %[[C1]], %[[LOAD]] : f32, index, index, f32, index, index, f32 + +func.func @slicing_memref_store_trivial() { + %0 = memref.alloc() : memref<10xf32> + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + affine.for %i1 = 0 to 10 { + %1 = affine.apply affine_map<()[s0] -> (s0)>()[%c0] + memref.store %cst, %0[%1] : memref<10xf32> + %2 = memref.load %0[%c0] : memref<10xf32> + %3 = affine.apply affine_map<()[] -> (0)>()[] + memref.store %cst, %0[%3] : memref<10xf32> + memref.store %2, %0[%c0] : memref<10xf32> + } + return +} diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp index 78c0ec97c0cdf..8a17a33c61838 100644 --- a/mlir/tools/mlir-query/mlir-query.cpp +++ b/mlir/tools/mlir-query/mlir-query.cpp @@ -40,12 +40,22 @@ int main(int argc, char **argv) { query::matcher::Registry matcherRegistry; // Matchers registered in alphabetical order for consistency: + matcherRegistry.registerMatcher("allOf", query::matcher::internal::allOf); + matcherRegistry.registerMatcher("anyOf", query::matcher::internal::anyOf); + matcherRegistry.registerMatcher( + "getAllDefinitions", + query::matcher::m_GetAllDefinitions); matcherRegistry.registerMatcher( "getDefinitions", query::matcher::m_GetDefinitions); matcherRegistry.registerMatcher( - "getAllDefinitions", - query::matcher::m_GetAllDefinitions); + "getDefinitionsByPredicate", + query::matcher::m_GetDefinitionsByPredicate); + matcherRegistry.registerMatcher( + "getUsersByPredicate", + query::matcher::m_GetUsersByPredicate); matcherRegistry.registerMatcher("hasOpAttrName", static_cast(m_Attr)); matcherRegistry.registerMatcher("hasOpName", static_cast(m_Op));