Skip to content

Commit 329dec9

Browse files
authored
[MLIR] Add reduction interface with tester to mlir-reduce (#166096)
Currently, we don't have support for patterns that need access to a `Tester` instance in `mlir-reduce`. This PR adds `DialectReductionPatternWithTesterInterface` to the set of supported interfaces. Dialects can implement this interface to inject the tester into their pattern classes.
1 parent 622d52d commit 329dec9

File tree

3 files changed

+28
-6
lines changed

3 files changed

+28
-6
lines changed

mlir/include/mlir/Reducer/ReductionPatternInterface.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
1111

1212
#include "mlir/IR/DialectInterface.h"
13+
#include "mlir/Reducer/Tester.h"
1314

1415
namespace mlir {
1516

@@ -47,10 +48,17 @@ class DialectReductionPatternInterface
4748
/// replacing an operation with a constant.
4849
virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0;
4950

51+
/// This method extends `populateReductionPatterns` by allowing reduction
52+
/// patterns to use a `Tester` instance. Some reduction patterns may need to
53+
/// run tester to determine whether certain transformations preserve the
54+
/// "interesting" behavior of the program. This is mostly useful when pattern
55+
/// should choose between multiple modifications.
56+
virtual void populateReductionPatternsWithTester(RewritePatternSet &patterns,
57+
Tester &tester) const {}
58+
5059
protected:
5160
DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
5261
};
53-
5462
} // namespace mlir
5563

5664
#endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H

mlir/include/mlir/Reducer/Tester.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ class Tester {
3636
Untested,
3737
};
3838

39+
Tester() = default;
40+
Tester(const Tester &) = default;
41+
3942
Tester(StringRef testScript, ArrayRef<std::string> testScriptArgs);
4043

4144
/// Runs the interestingness testing script on a MLIR test case file. Returns
@@ -46,6 +49,9 @@ class Tester {
4649
/// Return whether the file in the given path is interesting.
4750
Interestingness isInteresting(StringRef testCase) const;
4851

52+
void setTestScript(StringRef script) { testScript = script; }
53+
void setTestScriptArgs(ArrayRef<std::string> args) { testScriptArgs = args; }
54+
4955
private:
5056
StringRef testScript;
5157
ArrayRef<std::string> testScriptArgs;

mlir/lib/Reducer/ReductionTreePass.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,12 @@ class ReductionPatternInterfaceCollection
175175
using Base::Base;
176176

177177
// Collect the reduce patterns defined by each dialect.
178-
void populateReductionPatterns(RewritePatternSet &pattern) const {
179-
for (const DialectReductionPatternInterface &interface : *this)
178+
void populateReductionPatterns(RewritePatternSet &pattern,
179+
Tester &tester) const {
180+
for (const DialectReductionPatternInterface &interface : *this) {
180181
interface.populateReductionPatterns(pattern);
182+
interface.populateReductionPatternsWithTester(pattern, tester);
183+
}
181184
}
182185
};
183186

@@ -201,15 +204,21 @@ class ReductionTreePass
201204
private:
202205
LogicalResult reduceOp(ModuleOp module, Region &region);
203206

207+
Tester tester;
204208
FrozenRewritePatternSet reducerPatterns;
205209
};
206210

207211
} // namespace
208212

209213
LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
214+
tester.setTestScript(testerName);
215+
tester.setTestScriptArgs(testerArgs);
216+
210217
RewritePatternSet patterns(context);
218+
211219
ReductionPatternInterfaceCollection reducePatternCollection(context);
212-
reducePatternCollection.populateReductionPatterns(patterns);
220+
reducePatternCollection.populateReductionPatterns(patterns, tester);
221+
213222
reducerPatterns = std::move(patterns);
214223
return success();
215224
}
@@ -244,11 +253,10 @@ void ReductionTreePass::runOnOperation() {
244253
}
245254

246255
LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
247-
Tester test(testerName, testerArgs);
248256
switch (traversalModeId) {
249257
case TraversalMode::SinglePath:
250258
return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
251-
module, region, reducerPatterns, test);
259+
module, region, reducerPatterns, tester);
252260
default:
253261
return module.emitError() << "unsupported traversal mode detected";
254262
}

0 commit comments

Comments
 (0)