Skip to content

Commit 720b851

Browse files
committed
add replaceOperands option to mlir-reduce
1 parent 4a13f09 commit 720b851

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

mlir/include/mlir/Reducer/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ def ReductionTreePass : Pass<"reduction-tree"> {
3131
Option<"traversalModeId", "traversal-mode", "unsigned",
3232
/* default */"0",
3333
"The graph traversal mode, the default is single-path mode">,
34+
Option<"replaceOperands", "replace-operands", "bool",
35+
/* default */"false",
36+
"Whether the pass should replace operands with previously defined values with the same type">,
37+
3438
] # CommonReductionPassOptions.options;
3539
}
3640

mlir/lib/Reducer/ReductionTreePass.cpp

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
//===----------------------------------------------------------------------===//
1616

1717
#include "mlir/IR/DialectInterface.h"
18+
#include "mlir/IR/Dominance.h"
1819
#include "mlir/Reducer/Passes.h"
1920
#include "mlir/Reducer/ReductionNode.h"
2021
#include "mlir/Reducer/ReductionPatternInterface.h"
2122
#include "mlir/Reducer/Tester.h"
2223
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
24+
#include "mlir/Support/LLVM.h"
2325
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2426

2527
#include "llvm/ADT/ArrayRef.h"
@@ -38,7 +40,7 @@ using namespace mlir;
3840
static void applyPatterns(Region &region,
3941
const FrozenRewritePatternSet &patterns,
4042
ArrayRef<ReductionNode::Range> rangeToKeep,
41-
bool eraseOpNotInRange) {
43+
bool eraseOpNotInRange, bool replaceOperands) {
4244
std::vector<Operation *> opsNotInRange;
4345
std::vector<Operation *> opsInRange;
4446
size_t keepIndex = 0;
@@ -53,17 +55,33 @@ static void applyPatterns(Region &region,
5355
opsInRange.push_back(&op.value());
5456
}
5557

58+
DominanceInfo domInfo(region.getParentOp());
59+
mlir::DenseMap<mlir::Type, mlir::SmallVector<mlir::Value, 5>> valueMap;
60+
5661
// `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
5762
// pattern matching in above iteration. Besides, erase op not-in-range may end
5863
// up in invalid module, so `applyOpPatternsGreedily` with folding should come
5964
// before that transform.
6065
for (Operation *op : opsInRange) {
66+
if (replaceOperands)
67+
for (auto operandTie : llvm::enumerate(op->getOperands())) {
68+
size_t index = operandTie.index();
69+
auto operand = operandTie.value();
70+
for (auto candidate : valueMap[operand.getType()])
71+
if (domInfo.properlyDominates(candidate, op))
72+
op->setOperand(index, candidate);
73+
}
74+
6175
// `applyOpPatternsGreedily` with folding returns whether the op is
6276
// converted. Omit it because we don't have expectation this reduction will
6377
// be success or not.
6478
(void)applyOpPatternsGreedily(op, patterns,
6579
GreedyRewriteConfig().setStrictness(
6680
GreedyRewriteStrictness::ExistingOps));
81+
82+
if (op && replaceOperands)
83+
for (auto result : op->getResults())
84+
valueMap[result.getType()].push_back(result);
6785
}
6886

6987
if (eraseOpNotInRange)
@@ -83,7 +101,8 @@ static void applyPatterns(Region &region,
83101
template <typename IteratorType>
84102
static LogicalResult findOptimal(ModuleOp module, Region &region,
85103
const FrozenRewritePatternSet &patterns,
86-
const Tester &test, bool eraseOpNotInRange) {
104+
const Tester &test, bool eraseOpNotInRange,
105+
bool replaceOperands) {
87106
std::pair<Tester::Interestingness, size_t> initStatus =
88107
test.isInteresting(module);
89108
// While exploring the reduction tree, we always branch from an interesting
@@ -111,7 +130,7 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
111130
Region &curRegion = currentNode.getRegion();
112131

113132
applyPatterns(curRegion, patterns, currentNode.getRanges(),
114-
eraseOpNotInRange);
133+
eraseOpNotInRange, replaceOperands);
115134
currentNode.update(test.isInteresting(currentNode.getModule()));
116135

117136
if (currentNode.isInteresting() == Tester::Interestingness::True &&
@@ -134,7 +153,8 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
134153
// Reduce the region through the optimal path.
135154
while (!trace.empty()) {
136155
ReductionNode *top = trace.pop_back_val();
137-
applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
156+
applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange,
157+
replaceOperands);
138158
}
139159

140160
if (test.isInteresting(module).first != Tester::Interestingness::True)
@@ -148,19 +168,21 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
148168
template <typename IteratorType>
149169
static LogicalResult findOptimal(ModuleOp module, Region &region,
150170
const FrozenRewritePatternSet &patterns,
151-
const Tester &test) {
171+
const Tester &test, bool replaceOperands) {
152172
// We separate the reduction process into 2 steps, the first one is to erase
153173
// redundant operations and the second one is to apply the reducer patterns.
154174

155175
// In the first phase, we don't apply any patterns so that we only select the
156176
// range of operations to keep to the module stay interesting.
157177
if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
158-
/*eraseOpNotInRange=*/true)))
178+
/*eraseOpNotInRange=*/true,
179+
replaceOperands)))
159180
return failure();
160181
// In the second phase, we suppose that no operation is redundant, so we try
161182
// to rewrite the operation into simpler form.
162183
return findOptimal<IteratorType>(module, region, patterns, test,
163-
/*eraseOpNotInRange=*/false);
184+
/*eraseOpNotInRange=*/false,
185+
/*replaceOperands=*/false);
164186
}
165187

166188
namespace {
@@ -248,7 +270,7 @@ LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
248270
switch (traversalModeId) {
249271
case TraversalMode::SinglePath:
250272
return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
251-
module, region, reducerPatterns, test);
273+
module, region, reducerPatterns, test, replaceOperands);
252274
default:
253275
return module.emitError() << "unsupported traversal mode detected";
254276
}

0 commit comments

Comments
 (0)