Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion mlir/include/mlir/Reducer/ReductionPatternInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H

#include "mlir/IR/DialectInterface.h"
#include "mlir/Reducer/Tester.h"

namespace mlir {

Expand Down Expand Up @@ -47,10 +48,17 @@ class DialectReductionPatternInterface
/// replacing an operation with a constant.
virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0;

/// This method extends `populateReductionPatterns` by allowing reduction
/// patterns to use a `Tester` instance. Some reduction patterns may need to
/// run tester to determine whether certain transformations preserve the
/// "interesting" behavior of the program. This is mostly useful when pattern
/// should choose between multiple modifications.
virtual void populateReductionPatternsWithTester(RewritePatternSet &patterns,
Tester &tester) const {}

protected:
DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
};

} // namespace mlir

#endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
6 changes: 6 additions & 0 deletions mlir/include/mlir/Reducer/Tester.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class Tester {
Untested,
};

Tester() = default;
Tester(const Tester &) = default;

Tester(StringRef testScript, ArrayRef<std::string> testScriptArgs);

/// Runs the interestingness testing script on a MLIR test case file. Returns
Expand All @@ -46,6 +49,9 @@ class Tester {
/// Return whether the file in the given path is interesting.
Interestingness isInteresting(StringRef testCase) const;

void setTestScript(StringRef script) { testScript = script; }
void setTestScriptArgs(ArrayRef<std::string> args) { testScriptArgs = args; }

private:
StringRef testScript;
ArrayRef<std::string> testScriptArgs;
Expand Down
18 changes: 13 additions & 5 deletions mlir/lib/Reducer/ReductionTreePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,12 @@ class ReductionPatternInterfaceCollection
using Base::Base;

// Collect the reduce patterns defined by each dialect.
void populateReductionPatterns(RewritePatternSet &pattern) const {
for (const DialectReductionPatternInterface &interface : *this)
void populateReductionPatterns(RewritePatternSet &pattern,
Tester &tester) const {
for (const DialectReductionPatternInterface &interface : *this) {
interface.populateReductionPatterns(pattern);
interface.populateReductionPatternsWithTester(pattern, tester);
}
}
};

Expand All @@ -201,15 +204,21 @@ class ReductionTreePass
private:
LogicalResult reduceOp(ModuleOp module, Region &region);

Tester tester;
FrozenRewritePatternSet reducerPatterns;
};

} // namespace

LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
tester.setTestScript(testerName);
tester.setTestScriptArgs(testerArgs);

RewritePatternSet patterns(context);

ReductionPatternInterfaceCollection reducePatternCollection(context);
reducePatternCollection.populateReductionPatterns(patterns);
reducePatternCollection.populateReductionPatterns(patterns, tester);

reducerPatterns = std::move(patterns);
return success();
}
Expand Down Expand Up @@ -244,11 +253,10 @@ void ReductionTreePass::runOnOperation() {
}

LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
Tester test(testerName, testerArgs);
switch (traversalModeId) {
case TraversalMode::SinglePath:
return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
module, region, reducerPatterns, test);
module, region, reducerPatterns, tester);
default:
return module.emitError() << "unsupported traversal mode detected";
}
Expand Down