Skip to content

Commit d76de86

Browse files
committed
Use visitUsedValuesDefinedAbove to simplify the code.
1 parent 19a042e commit d76de86

File tree

1 file changed

+7
-16
lines changed

1 file changed

+7
-16
lines changed

mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
1616
#include "mlir/IR/Dominance.h"
1717
#include "mlir/IR/PatternMatch.h"
18+
#include "mlir/Transforms/RegionUtils.h"
1819

1920
using namespace mlir;
2021

@@ -89,22 +90,12 @@ struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
8990
}
9091
}
9192

92-
SmallVector<Value> additionalUsedValues;
93-
auto isValueUsedInsideIf = [&](Value val) {
94-
return llvm::any_of(val.getUsers(), [&](Operation *user) {
95-
return ifOp.getThenRegion().isAncestor(user->getParentRegion());
96-
});
97-
};
98-
9993
// Collect additional used values from before region.
100-
for (Operation *it = ifOp->getPrevNode(); it != nullptr;
101-
it = it->getPrevNode())
102-
llvm::copy_if(it->getResults(), std::back_inserter(additionalUsedValues),
103-
isValueUsedInsideIf);
104-
105-
llvm::copy_if(op.getBeforeArguments(),
106-
std::back_inserter(additionalUsedValues),
107-
isValueUsedInsideIf);
94+
SetVector<Value> additionalUsedValues;
95+
visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) {
96+
if (op.getBefore().isAncestor(operand->get().getParentRegion()))
97+
additionalUsedValues.insert(operand->get());
98+
});
10899

109100
// Create new whileOp with additional used values as results.
110101
auto additionalValueTypes = llvm::map_to_vector(
@@ -132,7 +123,7 @@ struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
132123
// Replace uses of additional used values inside the ifOp then region with
133124
// the whileOp after region arguments.
134125
rewriter.replaceUsesWithIf(
135-
additionalUsedValues,
126+
additionalUsedValues.takeVector(),
136127
newWhileOp.getAfterArguments().take_back(additionalValueSize),
137128
[&](OpOperand &use) {
138129
return ifOp.getThenRegion().isAncestor(

0 commit comments

Comments
 (0)