Skip to content

Commit 2c21790

Browse files
authored
Revert "[MLIR][SCF] Sink scf.if from scf.while before region into after region in scf-uplift-while-to-for" (#169888)
Reverts #165216 It is implemented in #169892 .
1 parent 29fef3a commit 2c21790

File tree

2 files changed

+1
-109
lines changed

2 files changed

+1
-109
lines changed

mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp

Lines changed: 1 addition & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -19,83 +19,6 @@
1919
using namespace mlir;
2020

2121
namespace {
22-
/// Move an scf.if op that is directly before the scf.condition op in the while
23-
/// before region, and whose condition matches the condition of the
24-
/// scf.condition op, down into the while after region.
25-
///
26-
/// scf.while (%init) : (...) -> ... {
27-
/// %cond = ...
28-
/// %res = scf.if %cond -> (...) {
29-
/// use1(%init)
30-
/// %then_val = ...
31-
/// ... // then block
32-
/// scf.yield %then_val
33-
/// } else {
34-
/// scf.yield %init
35-
/// }
36-
/// scf.condition(%cond) %res
37-
/// } do {
38-
/// ^bb0(%arg):
39-
/// use2(%arg)
40-
/// ...
41-
///
42-
/// becomes
43-
/// scf.while (%init) : (...) -> ... {
44-
/// %cond = ...
45-
/// scf.condition(%cond) %init
46-
/// } do {
47-
/// ^bb0(%arg): :
48-
/// use1(%arg)
49-
/// ... // if then block
50-
/// %then_val = ...
51-
/// use2(%then_val)
52-
/// ...
53-
struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
54-
using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
55-
56-
LogicalResult matchAndRewrite(scf::WhileOp op,
57-
PatternRewriter &rewriter) const override {
58-
// Check that the first opeation produces one result and that result must
59-
// have exactly two uses (these two uses come from the `scf.if` and
60-
// `scf.condition` operations).
61-
Operation &condOp = op.getBeforeBody()->front();
62-
if (condOp.getNumResults() != 1 || !condOp.getResult(0).hasNUses(2))
63-
return failure();
64-
65-
Value condVal = condOp.getResult(0);
66-
auto ifOp = dyn_cast<scf::IfOp>(condOp.getNextNode());
67-
if (!ifOp || ifOp.getCondition() != condVal)
68-
return failure();
69-
70-
auto term = dyn_cast<scf::ConditionOp>(ifOp->getNextNode());
71-
if (!term || term.getCondition() != condVal)
72-
return failure();
73-
74-
// Check that if results and else yield operands match the scf.condition op
75-
// arguments and while before arguments respectively.
76-
if (!llvm::equal(ifOp->getResults(), term.getArgs()) ||
77-
!llvm::equal(ifOp.elseYield()->getOperands(), op.getBeforeArguments()))
78-
return failure();
79-
80-
// Update uses and move the if op into the after region.
81-
rewriter.replaceAllUsesWith(op.getAfterArguments(),
82-
ifOp.thenYield()->getOperands());
83-
rewriter.replaceUsesWithIf(op.getBeforeArguments(), op.getAfterArguments(),
84-
[&](OpOperand &use) {
85-
return ifOp.getThenRegion().isAncestor(
86-
use.getOwner()->getParentRegion());
87-
});
88-
rewriter.modifyOpInPlace(
89-
term, [&]() { term.getArgsMutable().assign(op.getBeforeArguments()); });
90-
91-
rewriter.eraseOp(ifOp.thenYield());
92-
rewriter.inlineBlockBefore(ifOp.thenBlock(), op.getAfterBody(),
93-
op.getAfterBody()->begin());
94-
rewriter.eraseOp(ifOp);
95-
return success();
96-
}
97-
};
98-
9922
struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
10023
using OpRewritePattern::OpRewritePattern;
10124

@@ -344,5 +267,5 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
344267
}
345268

346269
void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
347-
patterns.add<UpliftWhileOp, WhileMoveIfDown>(patterns.getContext());
270+
patterns.add<UpliftWhileOp>(patterns.getContext());
348271
}

mlir/test/Dialect/SCF/uplift-while.mlir

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -185,34 +185,3 @@ func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32)
185185
// CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
186186
// CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32
187187
// CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32
188-
189-
// -----
190-
191-
func.func @uplift_while(%low: index, %upper: index, %val : i32) -> i32 {
192-
%c1 = arith.constant 1 : index
193-
%1:2 = scf.while (%iv = %low, %iter = %val) : (index, i32) -> (index, i32) {
194-
%2 = arith.cmpi slt, %iv, %upper : index
195-
%3:2 = scf.if %2 -> (index, i32) {
196-
%4 = "test.test"(%iter) : (i32) -> i32
197-
%5 = arith.addi %iv, %c1 : index
198-
scf.yield %5, %4 : index, i32
199-
} else {
200-
scf.yield %iv, %iter : index, i32
201-
}
202-
scf.condition(%2) %3#0, %3#1 : index, i32
203-
} do {
204-
^bb0(%arg0: index, %arg1: i32):
205-
scf.yield %arg0, %arg1 : index, i32
206-
}
207-
return %1#1 : i32
208-
}
209-
210-
// CHECK-LABEL: func.func @uplift_while(
211-
// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i32) -> i32 {
212-
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : index
213-
// CHECK: %[[FOR_0:.*]] = scf.for %[[VAL_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[CONSTANT_0]] iter_args(%[[VAL_1:.*]] = %[[ARG2]]) -> (i32) {
214-
// CHECK: %[[VAL_2:.*]] = "test.test"(%[[VAL_1]]) : (i32) -> i32
215-
// CHECK: scf.yield %[[VAL_2]] : i32
216-
// CHECK: }
217-
// CHECK: return %[[FOR_0]] : i32
218-
// CHECK: }

0 commit comments

Comments
 (0)