Skip to content

Conversation

@chios202
Copy link
Contributor

@chios202 chios202 commented Jun 27, 2025

An uninitialized variable that caused a crash (https://lab.llvm.org/buildbot/#/builders/164/builds/11004) was identified using the memory analyzer, leading to the reversion of #141423. This pull request reapplies the previously reverted changes and includes the fix, which has been tested locally following the steps at https://github.com/google/sanitizers/wiki/SanitizerBotReproduceBuild.

Note: the fix is included as part of the second commit

chios202 added 2 commits June 27, 2025 21:16
	Limit backward-slice with nested matching
	Add variadic operators

Add test cases
	Add test cases for variadic matchers
	Relocate variadic matchers

use signed for arithemtic & avoid copy

Add verifier check for extract function

Add slicing function extraction test; improve documentation; use lowercase for errors
@chios202 chios202 force-pushed the reapply-141423-mlir-query-combinators-plus-fix branch from 3d88b01 to fc67153 Compare June 27, 2025 21:19
@chios202 chios202 marked this pull request as ready for review June 27, 2025 21:19
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jun 27, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 27, 2025

@llvm/pr-subscribers-mlir

Author: Denzel-Brian Budii (chios202)

Changes

An uninitialized variable that caused a crash (https://lab.llvm.org/buildbot/#/builders/164/builds/11004) was identified using the memory analyzer, leading to the reversion of #141423. This pull request reapplies the previously reverted changes and includes the fix, which has been tested locally following the steps at https://github.com/google/sanitizers/wiki/SanitizerBotReproduceBuild.


Patch is 31.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146156.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Query/Matcher/Marshallers.h (+61)
  • (modified) mlir/include/mlir/Query/Matcher/MatchFinder.h (+3-1)
  • (modified) mlir/include/mlir/Query/Matcher/MatchersInternal.h (+112-4)
  • (modified) mlir/include/mlir/Query/Matcher/SliceMatchers.h (+99-5)
  • (modified) mlir/include/mlir/Query/Matcher/VariantValue.h (+10-1)
  • (modified) mlir/lib/Query/Matcher/CMakeLists.txt (+1)
  • (added) mlir/lib/Query/Matcher/MatchersInternal.cpp (+33)
  • (modified) mlir/lib/Query/Matcher/RegistryManager.cpp (+5-2)
  • (modified) mlir/lib/Query/Matcher/VariantValue.cpp (+54-2)
  • (modified) mlir/lib/Query/Query.cpp (+5)
  • (renamed) mlir/test/mlir-query/backward-slice-union.mlir (+11-2)
  • (added) mlir/test/mlir-query/forward-slice-by-predicate.mlir (+27)
  • (added) mlir/test/mlir-query/logical-operator-test.mlir (+11)
  • (added) mlir/test/mlir-query/slice-function-extraction.mlir (+29)
  • (modified) mlir/tools/mlir-query/mlir-query.cpp (+12-2)
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 012bf7b9ec4a9..5fe6965f32efb 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -108,6 +108,9 @@ class MatcherDescriptor {
                                 const llvm::ArrayRef<ParserValue> args,
                                 Diagnostics *error) const = 0;
 
+  // If the matcher is variadic, it can take any number of arguments.
+  virtual bool isVariadic() const = 0;
+
   // Returns the number of arguments accepted by the matcher.
   virtual unsigned getNumArgs() const = 0;
 
@@ -140,6 +143,8 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
     return marshaller(matcherFunc, matcherName, nameRange, args, error);
   }
 
+  bool isVariadic() const override { return false; }
+
   unsigned getNumArgs() const override { return argKinds.size(); }
 
   void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
@@ -153,6 +158,54 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
   const std::vector<ArgKind> argKinds;
 };
 
+class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
+public:
+  using VarOp = DynMatcher::VariadicOperator;
+  VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
+                                    VarOp varOp, StringRef matcherName)
+      : minCount(minCount), maxCount(maxCount), varOp(varOp),
+        matcherName(matcherName) {}
+
+  VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
+                        Diagnostics *error) const override {
+    if (args.size() < minCount || maxCount < args.size()) {
+      addError(error, nameRange, ErrorType::RegistryWrongArgCount,
+               {llvm::Twine("requires between "), llvm::Twine(minCount),
+                llvm::Twine(" and "), llvm::Twine(maxCount),
+                llvm::Twine(" args, got "), llvm::Twine(args.size())});
+      return VariantMatcher();
+    }
+
+    std::vector<VariantMatcher> innerArgs;
+    for (int64_t i = 0, e = args.size(); i != e; ++i) {
+      const ParserValue &arg = args[i];
+      const VariantValue &value = arg.value;
+      if (!value.isMatcher()) {
+        addError(error, arg.range, ErrorType::RegistryWrongArgType,
+                 {llvm::Twine(i + 1), llvm::Twine("matcher: "),
+                  llvm::Twine(value.getTypeAsString())});
+        return VariantMatcher();
+      }
+      innerArgs.push_back(value.getMatcher());
+    }
+    return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
+  }
+
+  bool isVariadic() const override { return true; }
+
+  unsigned getNumArgs() const override { return 0; }
+
+  void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
+    kinds.push_back(ArgKind(ArgKind::Matcher));
+  }
+
+private:
+  const unsigned minCount;
+  const unsigned maxCount;
+  const VarOp varOp;
+  const StringRef matcherName;
+};
+
 // Helper function to check if argument count matches expected count
 inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
                           llvm::ArrayRef<ParserValue> args,
@@ -224,6 +277,14 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
       reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
 }
 
+// Variadic operator overload.
+template <unsigned MinCount, unsigned MaxCount>
+std::unique_ptr<MatcherDescriptor>
+makeMatcherAutoMarshall(VariadicOperatorMatcherFunc<MinCount, MaxCount> func,
+                        StringRef matcherName) {
+  return std::make_unique<VariadicOperatorMatcherDescriptor>(
+      MinCount, MaxCount, func.varOp, matcherName);
+}
 } // namespace mlir::query::matcher::internal
 
 #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index f8abf20ef60bb..6d06ca13d1344 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -21,7 +21,9 @@
 
 namespace mlir::query::matcher {
 
-/// A class that provides utilities to find operations in the IR.
+/// Finds and collects matches from the IR. After construction
+/// `collectMatches` can be used to traverse the IR and apply
+/// matchers.
 class MatchFinder {
 
 public:
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 183b2514e109f..88109430b6feb 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -8,11 +8,11 @@
 //
 // Implements the base layer of the matcher framework.
 //
-// Matchers are methods that return a Matcher which provides a method one of the
-// following methods: match(Operation *op), match(Operation *op,
-// SetVector<Operation *> &matchedOps)
+// Matchers are methods that return a Matcher which provides a
+// `match(...)` method whose parameters define the context of the match.
+// Support includes simple (unary) matchers as well as matcher combinators
+// (anyOf, allOf, etc.)
 //
-// The matcher functions are defined in include/mlir/IR/Matchers.h.
 // This file contains the wrapper classes needed to construct matchers for
 // mlir-query.
 //
@@ -25,6 +25,15 @@
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
 
 namespace mlir::query::matcher {
+class DynMatcher;
+namespace internal {
+
+bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+                           ArrayRef<DynMatcher> innerMatchers);
+bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+                           ArrayRef<DynMatcher> innerMatchers);
+
+} // namespace internal
 
 // Defaults to false if T has no match() method with the signature:
 // match(Operation* op).
@@ -84,6 +93,27 @@ class MatcherFnImpl : public MatcherInterface {
   MatcherFn matcherFn;
 };
 
+// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
+// match the given operation.
+using VariadicOperatorFunction = bool (*)(Operation *op,
+                                          SetVector<Operation *> *matchedOps,
+                                          ArrayRef<DynMatcher> innerMatchers);
+
+template <VariadicOperatorFunction Func>
+class VariadicMatcher : public MatcherInterface {
+public:
+  VariadicMatcher(std::vector<DynMatcher> matchers)
+      : matchers(std::move(matchers)) {}
+
+  bool match(Operation *op) override { return Func(op, nullptr, matchers); }
+  bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
+    return Func(op, &matchedOps, matchers);
+  }
+
+private:
+  std::vector<DynMatcher> matchers;
+};
+
 // Matcher wraps a MatcherInterface implementation and provides match()
 // methods that redirect calls to the underlying implementation.
 class DynMatcher {
@@ -92,6 +122,31 @@ class DynMatcher {
   DynMatcher(MatcherInterface *implementation)
       : implementation(implementation) {}
 
+  // Construct from a variadic function.
+  enum VariadicOperator {
+    // Matches operations for which all provided matchers match.
+    AllOf,
+    // Matches operations for which at least one of the provided matchers
+    // matches.
+    AnyOf
+  };
+
+  static std::unique_ptr<DynMatcher>
+  constructVariadic(VariadicOperator Op,
+                    std::vector<DynMatcher> innerMatchers) {
+    switch (Op) {
+    case AllOf:
+      return std::make_unique<DynMatcher>(
+          new VariadicMatcher<internal::allOfVariadicOperator>(
+              std::move(innerMatchers)));
+    case AnyOf:
+      return std::make_unique<DynMatcher>(
+          new VariadicMatcher<internal::anyOfVariadicOperator>(
+              std::move(innerMatchers)));
+    }
+    llvm_unreachable("Invalid Op value.");
+  }
+
   template <typename MatcherFn>
   static std::unique_ptr<DynMatcher>
   constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
@@ -113,6 +168,59 @@ class DynMatcher {
   std::string functionName;
 };
 
+// VariadicOperatorMatcher related types.
+template <typename... Ps>
+class VariadicOperatorMatcher {
+public:
+  VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
+      : varOp(varOp), params(std::forward<Ps>(params)...) {}
+
+  operator std::unique_ptr<DynMatcher>() const & {
+    return DynMatcher::constructVariadic(
+        varOp, getMatchers(std::index_sequence_for<Ps...>()));
+  }
+
+  operator std::unique_ptr<DynMatcher>() && {
+    return DynMatcher::constructVariadic(
+        varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
+  }
+
+private:
+  // Helper method to unpack the tuple into a vector.
+  template <std::size_t... Is>
+  std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
+    return {DynMatcher(std::get<Is>(params))...};
+  }
+
+  template <std::size_t... Is>
+  std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
+    return {DynMatcher(std::get<Is>(std::move(params)))...};
+  }
+
+  const DynMatcher::VariadicOperator varOp;
+  std::tuple<Ps...> params;
+};
+
+// Overloaded function object to generate VariadicOperatorMatcher objects from
+// arbitrary matchers.
+template <unsigned MinCount, unsigned MaxCount>
+struct VariadicOperatorMatcherFunc {
+  DynMatcher::VariadicOperator varOp;
+
+  template <typename... Ms>
+  VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
+    static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
+                  "invalid number of parameters for variadic matcher");
+    return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
+  }
+};
+
+namespace internal {
+const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
+    anyOf = {DynMatcher::AnyOf};
+const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
+    allOf = {DynMatcher::AllOf};
+} // namespace internal
 } // namespace mlir::query::matcher
 
 #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
diff --git a/mlir/include/mlir/Query/Matcher/SliceMatchers.h b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
index 441205b3a9615..7181648f06f89 100644
--- a/mlir/include/mlir/Query/Matcher/SliceMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
@@ -6,7 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file provides matchers for MLIRQuery that peform slicing analysis
+// This file defines slicing-analysis matchers that extend and abstract the
+// core implementations from `SliceAnalysis.h`.
 //
 //===----------------------------------------------------------------------===//
 
@@ -16,9 +17,9 @@
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/IR/Operation.h"
 
-/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
-/// Additionally, it limits the slice computation to a certain depth level using
-/// a custom filter.
+/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
+/// if `innerMatcher` matches. The traversal stops once the desired depth level
+/// is reached.
 ///
 /// Example: starting from node 9, assuming the matcher
 /// computes the slice for the first two depth levels:
@@ -119,6 +120,77 @@ bool BackwardSliceMatcher<Matcher>::matches(
                            : backwardSlice.size() >= 1;
 }
 
+/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
+/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
+template <typename BaseMatcher, typename Filter>
+class PredicateBackwardSliceMatcher {
+public:
+  PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
+                                bool inclusive, bool omitBlockArguments,
+                                bool omitUsesFromAbove)
+      : innerMatcher(std::move(innerMatcher)),
+        filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
+        omitBlockArguments(omitBlockArguments),
+        omitUsesFromAbove(omitUsesFromAbove) {}
+
+  bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
+    backwardSlice.clear();
+    BackwardSliceOptions options;
+    options.inclusive = inclusive;
+    options.omitUsesFromAbove = omitUsesFromAbove;
+    options.omitBlockArguments = omitBlockArguments;
+    if (innerMatcher.match(rootOp)) {
+      options.filter = [&](Operation *subOp) {
+        return !filterMatcher.match(subOp);
+      };
+      LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options);
+      assert(result.succeeded() && "expected backward slice to succeed");
+      (void)result;
+      return options.inclusive ? backwardSlice.size() > 1
+                               : backwardSlice.size() >= 1;
+    }
+    return false;
+  }
+
+private:
+  BaseMatcher innerMatcher;
+  Filter filterMatcher;
+  bool inclusive;
+  bool omitBlockArguments;
+  bool omitUsesFromAbove;
+};
+
+/// Computes the forward-slice of all users reachable from `rootOp`,
+/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
+template <typename BaseMatcher, typename Filter>
+class PredicateForwardSliceMatcher {
+public:
+  PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
+                               bool inclusive)
+      : innerMatcher(std::move(innerMatcher)),
+        filterMatcher(std::move(filterMatcher)), inclusive(inclusive) {}
+
+  bool match(Operation *rootOp, SetVector<Operation *> &forwardSlice) {
+    forwardSlice.clear();
+    ForwardSliceOptions options;
+    options.inclusive = inclusive;
+    if (innerMatcher.match(rootOp)) {
+      options.filter = [&](Operation *subOp) {
+        return !filterMatcher.match(subOp);
+      };
+      getForwardSlice(rootOp, &forwardSlice, options);
+      return options.inclusive ? forwardSlice.size() > 1
+                               : forwardSlice.size() >= 1;
+    }
+    return false;
+  }
+
+private:
+  BaseMatcher innerMatcher;
+  Filter filterMatcher;
+  bool inclusive;
+};
+
 /// Matches transitive defs of a top-level operation up to N levels.
 template <typename Matcher>
 inline BackwardSliceMatcher<Matcher>
@@ -130,7 +202,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
                                        omitUsesFromAbove);
 }
 
-/// Matches all transitive defs of a top-level operation up to N levels
+/// Matches all transitive defs of a top-level operation up to N levels.
 template <typename Matcher>
 inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
                                                          int64_t maxDepth) {
@@ -139,6 +211,28 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
                                        false, false);
 }
 
+/// Matches all transitive defs of a top-level operation and stops where
+/// `filterMatcher` rejects.
+template <typename BaseMatcher, typename Filter>
+inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
+m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
+                            bool inclusive, bool omitBlockArguments,
+                            bool omitUsesFromAbove) {
+  return PredicateBackwardSliceMatcher<BaseMatcher, Filter>(
+      std::move(innerMatcher), std::move(filterMatcher), inclusive,
+      omitBlockArguments, omitUsesFromAbove);
+}
+
+/// Matches all users of a top-level operation and stops where
+/// `filterMatcher` rejects.
+template <typename BaseMatcher, typename Filter>
+inline PredicateForwardSliceMatcher<BaseMatcher, Filter>
+m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
+                      bool inclusive) {
+  return PredicateForwardSliceMatcher<BaseMatcher, Filter>(
+      std::move(innerMatcher), std::move(filterMatcher), inclusive);
+}
+
 } // namespace mlir::query::matcher
 
 #endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 98c0a18e25101..1a47576de1841 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -26,7 +26,12 @@ enum class ArgKind { Boolean, Matcher, Signed, String };
 // A variant matcher object to abstract simple and complex matchers into a
 // single object type.
 class VariantMatcher {
-  class MatcherOps;
+  class MatcherOps {
+  public:
+    std::optional<DynMatcher>
+    constructVariadicOperator(DynMatcher::VariadicOperator varOp,
+                              ArrayRef<VariantMatcher> innerMatchers) const;
+  };
 
   // Payload interface to be specialized by each matcher type. It follows a
   // similar interface as VariantMatcher itself.
@@ -43,6 +48,9 @@ class VariantMatcher {
 
   // Clones the provided matcher.
   static VariantMatcher SingleMatcher(DynMatcher matcher);
+  static VariantMatcher
+  VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
+                          ArrayRef<VariantMatcher> args);
 
   // Makes the matcher the "null" matcher.
   void reset();
@@ -61,6 +69,7 @@ class VariantMatcher {
       : value(std::move(value)) {}
 
   class SinglePayload;
+  class VariadicOpPayload;
 
   std::shared_ptr<const Payload> value;
 };
diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt
index 629479bf7adc1..ba202762fdfbb 100644
--- a/mlir/lib/Query/Matcher/CMakeLists.txt
+++ b/mlir/lib/Query/Matcher/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_library(MLIRQueryMatcher
   MatchFinder.cpp
+  MatchersInternal.cpp
   Parser.cpp
   RegistryManager.cpp
   VariantValue.cpp
diff --git a/mlir/lib/Query/Matcher/MatchersInternal.cpp b/mlir/lib/Query/Matcher/MatchersInternal.cpp
new file mode 100644
index 0000000000000..01f412ade846b
--- /dev/null
+++ b/mlir/lib/Query/Matcher/MatchersInternal.cpp
@@ -0,0 +1,33 @@
+//===--- MatchersInternal.cpp----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Query/Matcher/MatchersInternal.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir::query::matcher {
+
+namespace internal {
+
+bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+                           ArrayRef<DynMatcher> innerMatchers) {
+  return llvm::all_of(innerMatchers, [&](const DynMatcher &matcher) {
+    if (matchedOps)
+      return matcher.match(op, *matchedOps);
+    return matcher.match(op);
+  });
+}
+bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+                           ArrayRef<DynMatcher> innerMatchers) {
+  return llvm::any_of(innerMatchers, [&](const DynMatcher &matcher) {
+    if (matchedOps)
+      return matcher.match(op, *matchedOps);
+    return matcher.match(op);
+  });
+}
+} // namespace internal
+} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 4b511c5f009e7..08b610453b11a 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -64,7 +64,7 @@ std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
     unsigned argNumber = ctxEntry.second;
     std::vector<ArgKind> nextTypeSet;
 
-    if (argNumber < ctor->getNumArgs())
+    if (ctor->isVariadic() || argNumber < ctor->getNumArgs())
       ctor->getArgKinds(argNumber, nextTypeSet);
 
     typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
@@ -83,7 +83,7 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
     const internal::MatcherDescriptor &matcher = *m.getValue();
     llvm::StringRef name = m.getKey();
 
-    unsigned numArgs = matcher.getNumArgs();
+    unsigned numArgs = matcher.isVariadic() ? 1 : matcher.getNumArgs();
     std::vector<std::vector<ArgKind>> argKinds(numArgs);
 
     for (const ArgKind &kind : acceptedTypes) {
@@ -115,6 +115,9 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
       }
     }
 ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jun 27, 2025

@llvm/pr-subscribers-mlir-core

Author: Denzel-Brian Budii (chios202)

Changes

An uninitialized variable that caused a crash (https://lab.llvm.org/buildbot/#/builders/164/builds/11004) was identified using the memory analyzer, leading to the reversion of #141423. This pull request reapplies the previously reverted changes and includes the fix, which has been tested locally following the steps at https://github.com/google/sanitizers/wiki/SanitizerBotReproduceBuild.


Patch is 31.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146156.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Query/Matcher/Marshallers.h (+61)
  • (modified) mlir/include/mlir/Query/Matcher/MatchFinder.h (+3-1)
  • (modified) mlir/include/mlir/Query/Matcher/MatchersInternal.h (+112-4)
  • (modified) mlir/include/mlir/Query/Matcher/SliceMatchers.h (+99-5)
  • (modified) mlir/include/mlir/Query/Matcher/VariantValue.h (+10-1)
  • (modified) mlir/lib/Query/Matcher/CMakeLists.txt (+1)
  • (added) mlir/lib/Query/Matcher/MatchersInternal.cpp (+33)
  • (modified) mlir/lib/Query/Matcher/RegistryManager.cpp (+5-2)
  • (modified) mlir/lib/Query/Matcher/VariantValue.cpp (+54-2)
  • (modified) mlir/lib/Query/Query.cpp (+5)
  • (renamed) mlir/test/mlir-query/backward-slice-union.mlir (+11-2)
  • (added) mlir/test/mlir-query/forward-slice-by-predicate.mlir (+27)
  • (added) mlir/test/mlir-query/logical-operator-test.mlir (+11)
  • (added) mlir/test/mlir-query/slice-function-extraction.mlir (+29)
  • (modified) mlir/tools/mlir-query/mlir-query.cpp (+12-2)
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 012bf7b9ec4a9..5fe6965f32efb 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -108,6 +108,9 @@ class MatcherDescriptor {
                                 const llvm::ArrayRef<ParserValue> args,
                                 Diagnostics *error) const = 0;
 
+  // If the matcher is variadic, it can take any number of arguments.
+  virtual bool isVariadic() const = 0;
+
   // Returns the number of arguments accepted by the matcher.
   virtual unsigned getNumArgs() const = 0;
 
@@ -140,6 +143,8 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
     return marshaller(matcherFunc, matcherName, nameRange, args, error);
   }
 
+  bool isVariadic() const override { return false; }
+
   unsigned getNumArgs() const override { return argKinds.size(); }
 
   void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
@@ -153,6 +158,54 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
   const std::vector<ArgKind> argKinds;
 };
 
+class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
+public:
+  using VarOp = DynMatcher::VariadicOperator;
+  VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
+                                    VarOp varOp, StringRef matcherName)
+      : minCount(minCount), maxCount(maxCount), varOp(varOp),
+        matcherName(matcherName) {}
+
+  VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
+                        Diagnostics *error) const override {
+    if (args.size() < minCount || maxCount < args.size()) {
+      addError(error, nameRange, ErrorType::RegistryWrongArgCount,
+               {llvm::Twine("requires between "), llvm::Twine(minCount),
+                llvm::Twine(" and "), llvm::Twine(maxCount),
+                llvm::Twine(" args, got "), llvm::Twine(args.size())});
+      return VariantMatcher();
+    }
+
+    std::vector<VariantMatcher> innerArgs;
+    for (int64_t i = 0, e = args.size(); i != e; ++i) {
+      const ParserValue &arg = args[i];
+      const VariantValue &value = arg.value;
+      if (!value.isMatcher()) {
+        addError(error, arg.range, ErrorType::RegistryWrongArgType,
+                 {llvm::Twine(i + 1), llvm::Twine("matcher: "),
+                  llvm::Twine(value.getTypeAsString())});
+        return VariantMatcher();
+      }
+      innerArgs.push_back(value.getMatcher());
+    }
+    return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
+  }
+
+  bool isVariadic() const override { return true; }
+
+  unsigned getNumArgs() const override { return 0; }
+
+  void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
+    kinds.push_back(ArgKind(ArgKind::Matcher));
+  }
+
+private:
+  const unsigned minCount;
+  const unsigned maxCount;
+  const VarOp varOp;
+  const StringRef matcherName;
+};
+
 // Helper function to check if argument count matches expected count
 inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
                           llvm::ArrayRef<ParserValue> args,
@@ -224,6 +277,14 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
       reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
 }
 
+// Variadic operator overload.
+template <unsigned MinCount, unsigned MaxCount>
+std::unique_ptr<MatcherDescriptor>
+makeMatcherAutoMarshall(VariadicOperatorMatcherFunc<MinCount, MaxCount> func,
+                        StringRef matcherName) {
+  return std::make_unique<VariadicOperatorMatcherDescriptor>(
+      MinCount, MaxCount, func.varOp, matcherName);
+}
 } // namespace mlir::query::matcher::internal
 
 #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index f8abf20ef60bb..6d06ca13d1344 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -21,7 +21,9 @@
 
 namespace mlir::query::matcher {
 
-/// A class that provides utilities to find operations in the IR.
+/// Finds and collects matches from the IR. After construction
+/// `collectMatches` can be used to traverse the IR and apply
+/// matchers.
 class MatchFinder {
 
 public:
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 183b2514e109f..88109430b6feb 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -8,11 +8,11 @@
 //
 // Implements the base layer of the matcher framework.
 //
-// Matchers are methods that return a Matcher which provides a method one of the
-// following methods: match(Operation *op), match(Operation *op,
-// SetVector<Operation *> &matchedOps)
+// Matchers are methods that return a Matcher which provides a
+// `match(...)` method whose parameters define the context of the match.
+// Support includes simple (unary) matchers as well as matcher combinators
+// (anyOf, allOf, etc.)
 //
-// The matcher functions are defined in include/mlir/IR/Matchers.h.
 // This file contains the wrapper classes needed to construct matchers for
 // mlir-query.
 //
@@ -25,6 +25,15 @@
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
 
 namespace mlir::query::matcher {
+class DynMatcher;
+namespace internal {
+
+bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+                           ArrayRef<DynMatcher> innerMatchers);
+bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+                           ArrayRef<DynMatcher> innerMatchers);
+
+} // namespace internal
 
 // Defaults to false if T has no match() method with the signature:
 // match(Operation* op).
@@ -84,6 +93,27 @@ class MatcherFnImpl : public MatcherInterface {
   MatcherFn matcherFn;
 };
 
+// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
+// match the given operation.
+using VariadicOperatorFunction = bool (*)(Operation *op,
+                                          SetVector<Operation *> *matchedOps,
+                                          ArrayRef<DynMatcher> innerMatchers);
+
+template <VariadicOperatorFunction Func>
+class VariadicMatcher : public MatcherInterface {
+public:
+  VariadicMatcher(std::vector<DynMatcher> matchers)
+      : matchers(std::move(matchers)) {}
+
+  bool match(Operation *op) override { return Func(op, nullptr, matchers); }
+  bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
+    return Func(op, &matchedOps, matchers);
+  }
+
+private:
+  std::vector<DynMatcher> matchers;
+};
+
 // Matcher wraps a MatcherInterface implementation and provides match()
 // methods that redirect calls to the underlying implementation.
 class DynMatcher {
@@ -92,6 +122,31 @@ class DynMatcher {
   DynMatcher(MatcherInterface *implementation)
       : implementation(implementation) {}
 
+  // Construct from a variadic function.
+  enum VariadicOperator {
+    // Matches operations for which all provided matchers match.
+    AllOf,
+    // Matches operations for which at least one of the provided matchers
+    // matches.
+    AnyOf
+  };
+
+  static std::unique_ptr<DynMatcher>
+  constructVariadic(VariadicOperator Op,
+                    std::vector<DynMatcher> innerMatchers) {
+    switch (Op) {
+    case AllOf:
+      return std::make_unique<DynMatcher>(
+          new VariadicMatcher<internal::allOfVariadicOperator>(
+              std::move(innerMatchers)));
+    case AnyOf:
+      return std::make_unique<DynMatcher>(
+          new VariadicMatcher<internal::anyOfVariadicOperator>(
+              std::move(innerMatchers)));
+    }
+    llvm_unreachable("Invalid Op value.");
+  }
+
   template <typename MatcherFn>
   static std::unique_ptr<DynMatcher>
   constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
@@ -113,6 +168,59 @@ class DynMatcher {
   std::string functionName;
 };
 
+// VariadicOperatorMatcher related types.
+template <typename... Ps>
+class VariadicOperatorMatcher {
+public:
+  VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
+      : varOp(varOp), params(std::forward<Ps>(params)...) {}
+
+  operator std::unique_ptr<DynMatcher>() const & {
+    return DynMatcher::constructVariadic(
+        varOp, getMatchers(std::index_sequence_for<Ps...>()));
+  }
+
+  operator std::unique_ptr<DynMatcher>() && {
+    return DynMatcher::constructVariadic(
+        varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
+  }
+
+private:
+  // Helper method to unpack the tuple into a vector.
+  template <std::size_t... Is>
+  std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
+    return {DynMatcher(std::get<Is>(params))...};
+  }
+
+  template <std::size_t... Is>
+  std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
+    return {DynMatcher(std::get<Is>(std::move(params)))...};
+  }
+
+  const DynMatcher::VariadicOperator varOp;
+  std::tuple<Ps...> params;
+};
+
+// Overloaded function object to generate VariadicOperatorMatcher objects from
+// arbitrary matchers.
+template <unsigned MinCount, unsigned MaxCount>
+struct VariadicOperatorMatcherFunc {
+  DynMatcher::VariadicOperator varOp;
+
+  template <typename... Ms>
+  VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
+    static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
+                  "invalid number of parameters for variadic matcher");
+    return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
+  }
+};
+
+namespace internal {
+const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
+    anyOf = {DynMatcher::AnyOf};
+const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
+    allOf = {DynMatcher::AllOf};
+} // namespace internal
 } // namespace mlir::query::matcher
 
 #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
diff --git a/mlir/include/mlir/Query/Matcher/SliceMatchers.h b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
index 441205b3a9615..7181648f06f89 100644
--- a/mlir/include/mlir/Query/Matcher/SliceMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
@@ -6,7 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file provides matchers for MLIRQuery that peform slicing analysis
+// This file defines slicing-analysis matchers that extend and abstract the
+// core implementations from `SliceAnalysis.h`.
 //
 //===----------------------------------------------------------------------===//
 
@@ -16,9 +17,9 @@
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/IR/Operation.h"
 
-/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
-/// Additionally, it limits the slice computation to a certain depth level using
-/// a custom filter.
+/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
+/// if `innerMatcher` matches. The traversal stops once the desired depth level
+/// is reached.
 ///
 /// Example: starting from node 9, assuming the matcher
 /// computes the slice for the first two depth levels:
@@ -119,6 +120,77 @@ bool BackwardSliceMatcher<Matcher>::matches(
                            : backwardSlice.size() >= 1;
 }
 
+/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
+/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
+template <typename BaseMatcher, typename Filter>
+class PredicateBackwardSliceMatcher {
+public:
+  PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
+                                bool inclusive, bool omitBlockArguments,
+                                bool omitUsesFromAbove)
+      : innerMatcher(std::move(innerMatcher)),
+        filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
+        omitBlockArguments(omitBlockArguments),
+        omitUsesFromAbove(omitUsesFromAbove) {}
+
+  bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
+    backwardSlice.clear();
+    BackwardSliceOptions options;
+    options.inclusive = inclusive;
+    options.omitUsesFromAbove = omitUsesFromAbove;
+    options.omitBlockArguments = omitBlockArguments;
+    if (innerMatcher.match(rootOp)) {
+      options.filter = [&](Operation *subOp) {
+        return !filterMatcher.match(subOp);
+      };
+      LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options);
+      assert(result.succeeded() && "expected backward slice to succeed");
+      (void)result;
+      return options.inclusive ? backwardSlice.size() > 1
+                               : backwardSlice.size() >= 1;
+    }
+    return false;
+  }
+
+private:
+  BaseMatcher innerMatcher;
+  Filter filterMatcher;
+  bool inclusive;
+  bool omitBlockArguments;
+  bool omitUsesFromAbove;
+};
+
+/// Computes the forward-slice of all users reachable from `rootOp`,
+/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
+template <typename BaseMatcher, typename Filter>
+class PredicateForwardSliceMatcher {
+public:
+  PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
+                               bool inclusive)
+      : innerMatcher(std::move(innerMatcher)),
+        filterMatcher(std::move(filterMatcher)), inclusive(inclusive) {}
+
+  bool match(Operation *rootOp, SetVector<Operation *> &forwardSlice) {
+    forwardSlice.clear();
+    ForwardSliceOptions options;
+    options.inclusive = inclusive;
+    if (innerMatcher.match(rootOp)) {
+      options.filter = [&](Operation *subOp) {
+        return !filterMatcher.match(subOp);
+      };
+      getForwardSlice(rootOp, &forwardSlice, options);
+      return options.inclusive ? forwardSlice.size() > 1
+                               : forwardSlice.size() >= 1;
+    }
+    return false;
+  }
+
+private:
+  BaseMatcher innerMatcher;
+  Filter filterMatcher;
+  bool inclusive;
+};
+
 /// Matches transitive defs of a top-level operation up to N levels.
 template <typename Matcher>
 inline BackwardSliceMatcher<Matcher>
@@ -130,7 +202,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
                                        omitUsesFromAbove);
 }
 
-/// Matches all transitive defs of a top-level operation up to N levels
+/// Matches all transitive defs of a top-level operation up to N levels.
 template <typename Matcher>
 inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
                                                          int64_t maxDepth) {
@@ -139,6 +211,28 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
                                        false, false);
 }
 
+/// Matches all transitive defs of a top-level operation and stops where
+/// `filterMatcher` rejects.
+template <typename BaseMatcher, typename Filter>
+inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
+m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
+                            bool inclusive, bool omitBlockArguments,
+                            bool omitUsesFromAbove) {
+  return PredicateBackwardSliceMatcher<BaseMatcher, Filter>(
+      std::move(innerMatcher), std::move(filterMatcher), inclusive,
+      omitBlockArguments, omitUsesFromAbove);
+}
+
+/// Matches all users of a top-level operation and stops where
+/// `filterMatcher` rejects.
+template <typename BaseMatcher, typename Filter>
+inline PredicateForwardSliceMatcher<BaseMatcher, Filter>
+m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
+                      bool inclusive) {
+  return PredicateForwardSliceMatcher<BaseMatcher, Filter>(
+      std::move(innerMatcher), std::move(filterMatcher), inclusive);
+}
+
 } // namespace mlir::query::matcher
 
 #endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 98c0a18e25101..1a47576de1841 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -26,7 +26,12 @@ enum class ArgKind { Boolean, Matcher, Signed, String };
 // A variant matcher object to abstract simple and complex matchers into a
 // single object type.
 class VariantMatcher {
-  class MatcherOps;
+  class MatcherOps {
+  public:
+    std::optional<DynMatcher>
+    constructVariadicOperator(DynMatcher::VariadicOperator varOp,
+                              ArrayRef<VariantMatcher> innerMatchers) const;
+  };
 
   // Payload interface to be specialized by each matcher type. It follows a
   // similar interface as VariantMatcher itself.
@@ -43,6 +48,9 @@ class VariantMatcher {
 
   // Clones the provided matcher.
   static VariantMatcher SingleMatcher(DynMatcher matcher);
+  static VariantMatcher
+  VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
+                          ArrayRef<VariantMatcher> args);
 
   // Makes the matcher the "null" matcher.
   void reset();
@@ -61,6 +69,7 @@ class VariantMatcher {
       : value(std::move(value)) {}
 
   class SinglePayload;
+  class VariadicOpPayload;
 
   std::shared_ptr<const Payload> value;
 };
diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt
index 629479bf7adc1..ba202762fdfbb 100644
--- a/mlir/lib/Query/Matcher/CMakeLists.txt
+++ b/mlir/lib/Query/Matcher/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_library(MLIRQueryMatcher
   MatchFinder.cpp
+  MatchersInternal.cpp
   Parser.cpp
   RegistryManager.cpp
   VariantValue.cpp
diff --git a/mlir/lib/Query/Matcher/MatchersInternal.cpp b/mlir/lib/Query/Matcher/MatchersInternal.cpp
new file mode 100644
index 0000000000000..01f412ade846b
--- /dev/null
+++ b/mlir/lib/Query/Matcher/MatchersInternal.cpp
@@ -0,0 +1,33 @@
+//===--- MatchersInternal.cpp----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Query/Matcher/MatchersInternal.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir::query::matcher {
+
+namespace internal {
+
+bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+                           ArrayRef<DynMatcher> innerMatchers) {
+  return llvm::all_of(innerMatchers, [&](const DynMatcher &matcher) {
+    if (matchedOps)
+      return matcher.match(op, *matchedOps);
+    return matcher.match(op);
+  });
+}
+bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+                           ArrayRef<DynMatcher> innerMatchers) {
+  return llvm::any_of(innerMatchers, [&](const DynMatcher &matcher) {
+    if (matchedOps)
+      return matcher.match(op, *matchedOps);
+    return matcher.match(op);
+  });
+}
+} // namespace internal
+} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 4b511c5f009e7..08b610453b11a 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -64,7 +64,7 @@ std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
     unsigned argNumber = ctxEntry.second;
     std::vector<ArgKind> nextTypeSet;
 
-    if (argNumber < ctor->getNumArgs())
+    if (ctor->isVariadic() || argNumber < ctor->getNumArgs())
       ctor->getArgKinds(argNumber, nextTypeSet);
 
     typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
@@ -83,7 +83,7 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
     const internal::MatcherDescriptor &matcher = *m.getValue();
     llvm::StringRef name = m.getKey();
 
-    unsigned numArgs = matcher.getNumArgs();
+    unsigned numArgs = matcher.isVariadic() ? 1 : matcher.getNumArgs();
     std::vector<std::vector<ArgKind>> argKinds(numArgs);
 
     for (const ArgKind &kind : acceptedTypes) {
@@ -115,6 +115,9 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
       }
     }
 ...
[truncated]

@chios202 chios202 changed the title Reapply 141423 mlir query combinators plus fix [mlir] Reapply 141423 mlir-query combinators plus fix Jun 28, 2025
@jpienaar jpienaar merged commit 3702d64 into llvm:main Jul 1, 2025
11 checks passed
@llvm-ci
Copy link
Collaborator

llvm-ci commented Jul 1, 2025

LLVM Buildbot has detected a new failure on builder flang-arm64-windows-msvc running on linaro-armv8-windows-msvc-01 while building mlir at step 7 "test-build-unified-tree-check-flang".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/207/builds/3192

Here is the relevant piece of the build log for the reference
Step 7 (test-build-unified-tree-check-flang) failure: test (failure)
******************** TEST 'Flang :: Parser/OpenMP/declare-reduction-operator.f90' FAILED ********************
Exit Code: 2

Command Output (stdout):
--
# RUN: at line 1
c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe -fc1 -fdebug-unparse -fopenmp C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\llvm-project\flang\test\Parser\OpenMP\declare-reduction-operator.f90 | c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\filecheck.exe --ignore-case C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\llvm-project\flang\test\Parser\OpenMP\declare-reduction-operator.f90
# executed command: 'c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe' -fc1 -fdebug-unparse -fopenmp 'C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\llvm-project\flang\test\Parser\OpenMP\declare-reduction-operator.f90'
# executed command: 'c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\filecheck.exe' --ignore-case 'C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\llvm-project\flang\test\Parser\OpenMP\declare-reduction-operator.f90'
# RUN: at line 2
c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe -fc1 -fdebug-dump-parse-tree -fopenmp C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\llvm-project\flang\test\Parser\OpenMP\declare-reduction-operator.f90 | c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\filecheck.exe --check-prefix="PARSE-TREE" C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\llvm-project\flang\test\Parser\OpenMP\declare-reduction-operator.f90
# executed command: 'c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe' -fc1 -fdebug-dump-parse-tree -fopenmp 'C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\llvm-project\flang\test\Parser\OpenMP\declare-reduction-operator.f90'
# .---command stderr------------
# | PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
# | Stack dump:
# | 0.	Program arguments: c:\\users\\tcwg\\llvm-worker\\flang-arm64-windows-msvc\\build\\bin\\flang.exe -fc1 -fdebug-dump-parse-tree -fopenmp C:\\Users\\tcwg\\llvm-worker\\flang-arm64-windows-msvc\\llvm-project\\flang\\test\\Parser\\OpenMP\\declare-reduction-operator.f90
# | Exception Code: 0xC0000005
# |  #0 0x00007ff615261990 (c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x1a91990)
# |  #1 0x00007ff61525f678 (c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x1a8f678)
# |  #2 0x00007ff61525d50c (c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x1a8d50c)
# |  #3 0x00007ff615263220 (c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x1a93220)
# |  #4 0x00007ff61525d074 (c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x1a8d074)
# |  #5 0x00007ff61407bef8 (c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x8abef8)
# |  #6 0x00007ff6140c5358 (c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x8f5358)
# |  #7 0x00007ff61416e060 (c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x99e060)
# |  #8 0x00007ff6140c4794 (c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x8f4794)
# |  #9 0x00007ff613868a94 (c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x98a94)
# | #10 0x00007ff61387ea5c (c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0xaea5c)
# | #11 0x00007ff6137d34e8 (c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x34e8)
# | #12 0x00007ff6137d20a4 (c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x20a4)
# | #13 0x00007ff618478bb8 mlir::detail::FallbackTypeIDResolver::registerImplicitTypeID(class llvm::StringRef) (c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x4ca8bb8)
# | #14 0xa3177ff618478c54 
# `-----------------------------
# error: command failed with exit status: 0xc0000005
# executed command: 'c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\filecheck.exe' --check-prefix=PARSE-TREE 'C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\llvm-project\flang\test\Parser\OpenMP\declare-reduction-operator.f90'
# .---command stderr------------
# | FileCheck error: '<stdin>' is empty.
# | FileCheck command line:  c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\filecheck.exe --check-prefix=PARSE-TREE C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\llvm-project\flang\test\Parser\OpenMP\declare-reduction-operator.f90
# `-----------------------------
# error: command failed with exit status: 2

--

********************


Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants