1515// ===----------------------------------------------------------------------===//
1616
1717#include " mlir/IR/DialectInterface.h"
18+ #include " mlir/IR/Dominance.h"
1819#include " mlir/Reducer/Passes.h"
1920#include " mlir/Reducer/ReductionNode.h"
2021#include " mlir/Reducer/ReductionPatternInterface.h"
2122#include " mlir/Reducer/Tester.h"
2223#include " mlir/Rewrite/FrozenRewritePatternSet.h"
24+ #include " mlir/Support/LLVM.h"
2325#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
2426
2527#include " llvm/ADT/ArrayRef.h"
@@ -38,7 +40,7 @@ using namespace mlir;
3840static void applyPatterns (Region ®ion,
3941 const FrozenRewritePatternSet &patterns,
4042 ArrayRef<ReductionNode::Range> rangeToKeep,
41- bool eraseOpNotInRange) {
43+ bool eraseOpNotInRange, bool replaceOperands ) {
4244 std::vector<Operation *> opsNotInRange;
4345 std::vector<Operation *> opsInRange;
4446 size_t keepIndex = 0 ;
@@ -53,17 +55,33 @@ static void applyPatterns(Region ®ion,
5355 opsInRange.push_back (&op.value ());
5456 }
5557
58+ DominanceInfo domInfo (region.getParentOp ());
59+ mlir::DenseMap<mlir::Type, mlir::SmallVector<mlir::Value, 5 >> valueMap;
60+
5661 // `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
5762 // pattern matching in above iteration. Besides, erase op not-in-range may end
5863 // up in invalid module, so `applyOpPatternsGreedily` with folding should come
5964 // before that transform.
6065 for (Operation *op : opsInRange) {
66+ if (replaceOperands)
67+ for (auto operandTie : llvm::enumerate (op->getOperands ())) {
68+ size_t index = operandTie.index ();
69+ auto operand = operandTie.value ();
70+ for (auto candidate : valueMap[operand.getType ()])
71+ if (domInfo.properlyDominates (candidate, op))
72+ op->setOperand (index, candidate);
73+ }
74+
6175 // `applyOpPatternsGreedily` with folding returns whether the op is
6276 // converted. Omit it because we don't have expectation this reduction will
6377 // be success or not.
6478 (void )applyOpPatternsGreedily (op, patterns,
6579 GreedyRewriteConfig ().setStrictness (
6680 GreedyRewriteStrictness::ExistingOps));
81+
82+ if (op && replaceOperands)
83+ for (auto result : op->getResults ())
84+ valueMap[result.getType ()].push_back (result);
6785 }
6886
6987 if (eraseOpNotInRange)
@@ -83,7 +101,8 @@ static void applyPatterns(Region ®ion,
83101template <typename IteratorType>
84102static LogicalResult findOptimal (ModuleOp module , Region ®ion,
85103 const FrozenRewritePatternSet &patterns,
86- const Tester &test, bool eraseOpNotInRange) {
104+ const Tester &test, bool eraseOpNotInRange,
105+ bool replaceOperands) {
87106 std::pair<Tester::Interestingness, size_t > initStatus =
88107 test.isInteresting (module );
89108 // While exploring the reduction tree, we always branch from an interesting
@@ -111,7 +130,7 @@ static LogicalResult findOptimal(ModuleOp module, Region ®ion,
111130 Region &curRegion = currentNode.getRegion ();
112131
113132 applyPatterns (curRegion, patterns, currentNode.getRanges (),
114- eraseOpNotInRange);
133+ eraseOpNotInRange, replaceOperands );
115134 currentNode.update (test.isInteresting (currentNode.getModule ()));
116135
117136 if (currentNode.isInteresting () == Tester::Interestingness::True &&
@@ -134,7 +153,8 @@ static LogicalResult findOptimal(ModuleOp module, Region ®ion,
134153 // Reduce the region through the optimal path.
135154 while (!trace.empty ()) {
136155 ReductionNode *top = trace.pop_back_val ();
137- applyPatterns (region, patterns, top->getStartRanges (), eraseOpNotInRange);
156+ applyPatterns (region, patterns, top->getStartRanges (), eraseOpNotInRange,
157+ replaceOperands);
138158 }
139159
140160 if (test.isInteresting (module ).first != Tester::Interestingness::True)
@@ -148,19 +168,21 @@ static LogicalResult findOptimal(ModuleOp module, Region ®ion,
148168template <typename IteratorType>
149169static LogicalResult findOptimal (ModuleOp module , Region ®ion,
150170 const FrozenRewritePatternSet &patterns,
151- const Tester &test) {
171+ const Tester &test, bool replaceOperands ) {
152172 // We separate the reduction process into 2 steps, the first one is to erase
153173 // redundant operations and the second one is to apply the reducer patterns.
154174
155175 // In the first phase, we don't apply any patterns so that we only select the
156176 // range of operations to keep to the module stay interesting.
157177 if (failed (findOptimal<IteratorType>(module , region, /* patterns=*/ {}, test,
158- /* eraseOpNotInRange=*/ true )))
178+ /* eraseOpNotInRange=*/ true ,
179+ replaceOperands)))
159180 return failure ();
160181 // In the second phase, we suppose that no operation is redundant, so we try
161182 // to rewrite the operation into simpler form.
162183 return findOptimal<IteratorType>(module , region, patterns, test,
163- /* eraseOpNotInRange=*/ false );
184+ /* eraseOpNotInRange=*/ false ,
185+ /* replaceOperands=*/ false );
164186}
165187
166188namespace {
@@ -248,7 +270,7 @@ LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) {
248270 switch (traversalModeId) {
249271 case TraversalMode::SinglePath:
250272 return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
251- module , region, reducerPatterns, test);
273+ module , region, reducerPatterns, test, replaceOperands );
252274 default :
253275 return module .emitError () << " unsupported traversal mode detected" ;
254276 }
0 commit comments