|
15 | 15 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
16 | 16 | #include "mlir/IR/Dominance.h" |
17 | 17 | #include "mlir/IR/PatternMatch.h" |
| 18 | +#include "mlir/Transforms/RegionUtils.h" |
18 | 19 |
|
19 | 20 | using namespace mlir; |
20 | 21 |
|
@@ -89,22 +90,12 @@ struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> { |
89 | 90 | } |
90 | 91 | } |
91 | 92 |
|
92 | | - SmallVector<Value> additionalUsedValues; |
93 | | - auto isValueUsedInsideIf = [&](Value val) { |
94 | | - return llvm::any_of(val.getUsers(), [&](Operation *user) { |
95 | | - return ifOp.getThenRegion().isAncestor(user->getParentRegion()); |
96 | | - }); |
97 | | - }; |
98 | | - |
99 | 93 | // Collect additional used values from before region. |
100 | | - for (Operation *it = ifOp->getPrevNode(); it != nullptr; |
101 | | - it = it->getPrevNode()) |
102 | | - llvm::copy_if(it->getResults(), std::back_inserter(additionalUsedValues), |
103 | | - isValueUsedInsideIf); |
104 | | - |
105 | | - llvm::copy_if(op.getBeforeArguments(), |
106 | | - std::back_inserter(additionalUsedValues), |
107 | | - isValueUsedInsideIf); |
| 94 | + SetVector<Value> additionalUsedValues; |
| 95 | + visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) { |
| 96 | + if (op.getBefore().isAncestor(operand->get().getParentRegion())) |
| 97 | + additionalUsedValues.insert(operand->get()); |
| 98 | + }); |
108 | 99 |
|
109 | 100 | // Create new whileOp with additional used values as results. |
110 | 101 | auto additionalValueTypes = llvm::map_to_vector( |
@@ -132,7 +123,7 @@ struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> { |
132 | 123 | // Replace uses of additional used values inside the ifOp then region with |
133 | 124 | // the whileOp after region arguments. |
134 | 125 | rewriter.replaceUsesWithIf( |
135 | | - additionalUsedValues, |
| 126 | + additionalUsedValues.takeVector(), |
136 | 127 | newWhileOp.getAfterArguments().take_back(additionalValueSize), |
137 | 128 | [&](OpOperand &use) { |
138 | 129 | return ifOp.getThenRegion().isAncestor( |
|
0 commit comments