Skip to content

Conversation

@aidint
Copy link
Contributor

@aidint aidint commented Nov 2, 2025

Currently, we don't have support for patterns that need access to a Tester instance in mlir-reduce. This PR adds DialectReductionPatternWithTesterInterface to the set of supported interfaces. Dialects can implement this interface to inject the tester into their pattern classes.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Nov 2, 2025
@aidint aidint changed the title mlir-reduce: add reduction interface with tester [MLIR] Add reduction interface with tester to mlir-reduce Nov 2, 2025
@llvmbot
Copy link
Member

llvmbot commented Nov 2, 2025

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: AidinT (aidint)

Changes

Currently we don't have a support for patterns that need access to Tester instance in mlir-reduce. This PR adds DialectReductionPatternWithTesterInterface to the set of supported interfaces. Dialects can use this interface to inject tester into the pattern classes.


Full diff: https://github.com/llvm/llvm-project/pull/166096.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Reducer/ReductionPatternInterface.h (+26)
  • (modified) mlir/include/mlir/Reducer/Tester.h (+6)
  • (modified) mlir/lib/Reducer/ReductionTreePass.cpp (+29-2)
diff --git a/mlir/include/mlir/Reducer/ReductionPatternInterface.h b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
index a85562fda4d93..c4d7a94479358 100644
--- a/mlir/include/mlir/Reducer/ReductionPatternInterface.h
+++ b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
@@ -10,6 +10,7 @@
 #define MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
 
 #include "mlir/IR/DialectInterface.h"
+#include "mlir/Reducer/Tester.h"
 
 namespace mlir {
 
@@ -51,6 +52,31 @@ class DialectReductionPatternInterface
   DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
 };
 
+/// This interface extends the `dialectreductionpatterninterface` by allowing
+/// reduction patterns to use a `Tester` instance. Some reduction patterns may
+/// need to run tester to determine whether certain transformations preserve the
+/// "interesting" behavior of the program. This is mostly useful when pattern
+/// should choose between multiple modifications.
+/// Implementation follows the same logic as the
+/// `dialectreductionpatterninterface`.
+///
+/// Example:
+///   MyDialectReductionPatternWithTester::populateReductionPatterns(
+///       RewritePatternSet &patterns, Tester &tester) {
+///       patterns.add<PatternWithTester>(patterns.getContext(), tester);
+///   }
+class DialectReductionPatternWithTesterInterface
+    : public DialectInterface::Base<
+          DialectReductionPatternWithTesterInterface> {
+public:
+  virtual void populateReductionPatterns(RewritePatternSet &patterns,
+                                         Tester &tester) const = 0;
+
+protected:
+  DialectReductionPatternWithTesterInterface(Dialect *dialect)
+      : Base(dialect) {}
+};
+
 } // namespace mlir
 
 #endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
diff --git a/mlir/include/mlir/Reducer/Tester.h b/mlir/include/mlir/Reducer/Tester.h
index eb44afc7c1c15..bed4408342034 100644
--- a/mlir/include/mlir/Reducer/Tester.h
+++ b/mlir/include/mlir/Reducer/Tester.h
@@ -36,6 +36,9 @@ class Tester {
     Untested,
   };
 
+  Tester() = default;
+  Tester(const Tester &) = default;
+
   Tester(StringRef testScript, ArrayRef<std::string> testScriptArgs);
 
   /// Runs the interestingness testing script on a MLIR test case file. Returns
@@ -46,6 +49,9 @@ class Tester {
   /// Return whether the file in the given path is interesting.
   Interestingness isInteresting(StringRef testCase) const;
 
+  void setTestScript(StringRef script) { testScript = script; }
+  void setTestScriptArgs(ArrayRef<std::string> args) { testScriptArgs = args; }
+
 private:
   StringRef testScript;
   ArrayRef<std::string> testScriptArgs;
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 5b49204013cc0..af94cd798f629 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -181,6 +181,24 @@ class ReductionPatternInterfaceCollection
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Reduction Pattern With Tester Interface Collection
+//===----------------------------------------------------------------------===//
+
+class ReductionPatternWithTesterInterfaceCollection
+    : public DialectInterfaceCollection<
+          DialectReductionPatternWithTesterInterface> {
+public:
+  using Base::Base;
+
+  // Collect the reduce patterns defined by each dialect.
+  void populateReductionPatterns(RewritePatternSet &pattern,
+                                 Tester &tester) const {
+    for (const DialectReductionPatternWithTesterInterface &interface : *this)
+      interface.populateReductionPatterns(pattern, tester);
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ReductionTreePass
 //===----------------------------------------------------------------------===//
@@ -201,15 +219,25 @@ class ReductionTreePass
 private:
   LogicalResult reduceOp(ModuleOp module, Region &region);
 
+  Tester tester;
   FrozenRewritePatternSet reducerPatterns;
 };
 
 } // namespace
 
 LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
+  tester.setTestScript(testerName);
+  tester.setTestScriptArgs(testerArgs);
+
   RewritePatternSet patterns(context);
+
   ReductionPatternInterfaceCollection reducePatternCollection(context);
   reducePatternCollection.populateReductionPatterns(patterns);
+
+  ReductionPatternWithTesterInterfaceCollection
+      reducePatternWithTesterCollection(context);
+  reducePatternWithTesterCollection.populateReductionPatterns(patterns, tester);
+
   reducerPatterns = std::move(patterns);
   return success();
 }
@@ -244,11 +272,10 @@ void ReductionTreePass::runOnOperation() {
 }
 
 LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
-  Tester test(testerName, testerArgs);
   switch (traversalModeId) {
   case TraversalMode::SinglePath:
     return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
-        module, region, reducerPatterns, test);
+        module, region, reducerPatterns, tester);
   default:
     return module.emitError() << "unsupported traversal mode detected";
   }

@github-actions
Copy link

github-actions bot commented Nov 5, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@aidint aidint force-pushed the mlir-reduce-add-patterns-with-tester branch from 0612a34 to c7898be Compare November 5, 2025 20:43
@aidint aidint force-pushed the mlir-reduce-add-patterns-with-tester branch from c7898be to 04bdf55 Compare November 5, 2025 20:56
Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, looks like a rather minimal change which enables a few different ways of reducing.

Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good, and we can tune with arg removal pattern usage in follow up.

@jpienaar jpienaar merged commit 329dec9 into llvm:main Nov 13, 2025
10 checks passed
@aidint aidint deleted the mlir-reduce-add-patterns-with-tester branch November 16, 2025 15:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants