-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[MLIR] Add reduction interface with tester to mlir-reduce #166096
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR] Add reduction interface with tester to mlir-reduce #166096
Conversation
|
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: AidinT (aidint) ChangesCurrently we don't have a support for patterns that need access to Full diff: https://github.com/llvm/llvm-project/pull/166096.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Reducer/ReductionPatternInterface.h b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
index a85562fda4d93..c4d7a94479358 100644
--- a/mlir/include/mlir/Reducer/ReductionPatternInterface.h
+++ b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
@@ -10,6 +10,7 @@
#define MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
#include "mlir/IR/DialectInterface.h"
+#include "mlir/Reducer/Tester.h"
namespace mlir {
@@ -51,6 +52,31 @@ class DialectReductionPatternInterface
DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
};
+/// This interface extends the `dialectreductionpatterninterface` by allowing
+/// reduction patterns to use a `Tester` instance. Some reduction patterns may
+/// need to run tester to determine whether certain transformations preserve the
+/// "interesting" behavior of the program. This is mostly useful when pattern
+/// should choose between multiple modifications.
+/// Implementation follows the same logic as the
+/// `dialectreductionpatterninterface`.
+///
+/// Example:
+/// MyDialectReductionPatternWithTester::populateReductionPatterns(
+/// RewritePatternSet &patterns, Tester &tester) {
+/// patterns.add<PatternWithTester>(patterns.getContext(), tester);
+/// }
+class DialectReductionPatternWithTesterInterface
+ : public DialectInterface::Base<
+ DialectReductionPatternWithTesterInterface> {
+public:
+ virtual void populateReductionPatterns(RewritePatternSet &patterns,
+ Tester &tester) const = 0;
+
+protected:
+ DialectReductionPatternWithTesterInterface(Dialect *dialect)
+ : Base(dialect) {}
+};
+
} // namespace mlir
#endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
diff --git a/mlir/include/mlir/Reducer/Tester.h b/mlir/include/mlir/Reducer/Tester.h
index eb44afc7c1c15..bed4408342034 100644
--- a/mlir/include/mlir/Reducer/Tester.h
+++ b/mlir/include/mlir/Reducer/Tester.h
@@ -36,6 +36,9 @@ class Tester {
Untested,
};
+ Tester() = default;
+ Tester(const Tester &) = default;
+
Tester(StringRef testScript, ArrayRef<std::string> testScriptArgs);
/// Runs the interestingness testing script on a MLIR test case file. Returns
@@ -46,6 +49,9 @@ class Tester {
/// Return whether the file in the given path is interesting.
Interestingness isInteresting(StringRef testCase) const;
+ void setTestScript(StringRef script) { testScript = script; }
+ void setTestScriptArgs(ArrayRef<std::string> args) { testScriptArgs = args; }
+
private:
StringRef testScript;
ArrayRef<std::string> testScriptArgs;
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 5b49204013cc0..af94cd798f629 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -181,6 +181,24 @@ class ReductionPatternInterfaceCollection
}
};
+//===----------------------------------------------------------------------===//
+// Reduction Pattern With Tester Interface Collection
+//===----------------------------------------------------------------------===//
+
+class ReductionPatternWithTesterInterfaceCollection
+ : public DialectInterfaceCollection<
+ DialectReductionPatternWithTesterInterface> {
+public:
+ using Base::Base;
+
+ // Collect the reduce patterns defined by each dialect.
+ void populateReductionPatterns(RewritePatternSet &pattern,
+ Tester &tester) const {
+ for (const DialectReductionPatternWithTesterInterface &interface : *this)
+ interface.populateReductionPatterns(pattern, tester);
+ }
+};
+
//===----------------------------------------------------------------------===//
// ReductionTreePass
//===----------------------------------------------------------------------===//
@@ -201,15 +219,25 @@ class ReductionTreePass
private:
LogicalResult reduceOp(ModuleOp module, Region ®ion);
+ Tester tester;
FrozenRewritePatternSet reducerPatterns;
};
} // namespace
LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
+ tester.setTestScript(testerName);
+ tester.setTestScriptArgs(testerArgs);
+
RewritePatternSet patterns(context);
+
ReductionPatternInterfaceCollection reducePatternCollection(context);
reducePatternCollection.populateReductionPatterns(patterns);
+
+ ReductionPatternWithTesterInterfaceCollection
+ reducePatternWithTesterCollection(context);
+ reducePatternWithTesterCollection.populateReductionPatterns(patterns, tester);
+
reducerPatterns = std::move(patterns);
return success();
}
@@ -244,11 +272,10 @@ void ReductionTreePass::runOnOperation() {
}
LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) {
- Tester test(testerName, testerArgs);
switch (traversalModeId) {
case TraversalMode::SinglePath:
return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
- module, region, reducerPatterns, test);
+ module, region, reducerPatterns, tester);
default:
return module.emitError() << "unsupported traversal mode detected";
}
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
0612a34 to
c7898be
Compare
c7898be to
04bdf55
Compare
jpienaar
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, looks like a rather minimal change which enables a few different ways of reducing.
jpienaar
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good, and we can tune with arg removal pattern usage in follow up.
Currently, we don't have support for patterns that need access to a
Testerinstance inmlir-reduce. This PR addsDialectReductionPatternWithTesterInterfaceto the set of supported interfaces. Dialects can implement this interface to inject the tester into their pattern classes.