Skip to content

Commit db05b39

Browse files
committed
Move pattern into populateUpliftWhileToForPatterns.
1 parent a289930 commit db05b39

File tree

4 files changed

+110
-112
lines changed

4 files changed

+110
-112
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 1 addition & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -3546,83 +3546,6 @@ LogicalResult scf::WhileOp::verify() {
35463546
}
35473547

35483548
namespace {
3549-
/// Move an scf.if op that is directly before the scf.condition op in the while
3550-
/// before region, and whose condition matches the condition of the
3551-
/// scf.condition op, down into the while after region.
3552-
///
3553-
/// scf.while (%init) : (...) -> ... {
3554-
/// %cond = ...
3555-
/// %res = scf.if %cond -> (...) {
3556-
/// use1(%init)
3557-
/// %then_val = ...
3558-
/// ... // then block
3559-
/// scf.yield %then_val
3560-
/// } else {
3561-
/// scf.yield %init
3562-
/// }
3563-
/// scf.condition(%cond) %res
3564-
/// } do {
3565-
/// ^bb0(%arg):
3566-
/// use2(%arg)
3567-
/// ...
3568-
///
3569-
/// becomes
3570-
/// scf.while (%init) : (...) -> ... {
3571-
/// %cond = ...
3572-
/// scf.condition(%cond) %init
3573-
/// } do {
3574-
/// ^bb0(%arg): :
3575-
/// use1(%arg)
3576-
/// ... // if then block
3577-
/// %then_val = ...
3578-
/// use2(%then_val)
3579-
/// ...
3580-
struct WhileMoveIfDown : public OpRewritePattern<WhileOp> {
3581-
using OpRewritePattern<WhileOp>::OpRewritePattern;
3582-
3583-
LogicalResult matchAndRewrite(WhileOp op,
3584-
PatternRewriter &rewriter) const override {
3585-
// Check that the first opeation produces one result and that result must
3586-
// have exactly two uses (these two uses come from the `scf.if` and
3587-
// `scf.condition` operations).
3588-
Operation &condOp = op.getBeforeBody()->front();
3589-
if (condOp.getNumResults() != 1 || !condOp.getResult(0).hasNUses(2))
3590-
return failure();
3591-
3592-
Value condVal = condOp.getResult(0);
3593-
auto ifOp = dyn_cast<scf::IfOp>(condOp.getNextNode());
3594-
if (!ifOp || ifOp.getCondition() != condVal)
3595-
return failure();
3596-
3597-
auto term = dyn_cast<scf::ConditionOp>(ifOp->getNextNode());
3598-
if (!term || term.getCondition() != condVal)
3599-
return failure();
3600-
3601-
// Check that if results and else yield operands match the scf.condition op
3602-
// arguments and while before arguments respectively.
3603-
if (!llvm::equal(ifOp->getResults(), term.getArgs()) ||
3604-
!llvm::equal(ifOp.elseYield()->getOperands(), op.getBeforeArguments()))
3605-
return failure();
3606-
3607-
// Update uses and move the if op into the after region.
3608-
rewriter.replaceAllUsesWith(op.getAfterArguments(),
3609-
ifOp.thenYield()->getOperands());
3610-
rewriter.replaceUsesWithIf(op.getBeforeArguments(), op.getAfterArguments(),
3611-
[&](OpOperand &use) {
3612-
return ifOp.getThenRegion().isAncestor(
3613-
use.getOwner()->getParentRegion());
3614-
});
3615-
rewriter.modifyOpInPlace(
3616-
term, [&]() { term.getArgsMutable().assign(op.getBeforeArguments()); });
3617-
3618-
rewriter.eraseOp(ifOp.thenYield());
3619-
rewriter.inlineBlockBefore(ifOp.thenBlock(), op.getAfterBody(),
3620-
op.getAfterBody()->begin());
3621-
rewriter.eraseOp(ifOp);
3622-
return success();
3623-
}
3624-
};
3625-
36263549
/// Replace uses of the condition within the do block with true, since otherwise
36273550
/// the block would not be evaluated.
36283551
///
@@ -4335,8 +4258,7 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
43354258
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
43364259
RemoveLoopInvariantValueYielded, WhileConditionTruth,
43374260
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4338-
WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
4339-
context);
4261+
WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
43404262
}
43414263

43424264
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,83 @@
1919
using namespace mlir;
2020

2121
namespace {
22+
/// Move an scf.if op that is directly before the scf.condition op in the while
23+
/// before region, and whose condition matches the condition of the
24+
/// scf.condition op, down into the while after region.
25+
///
26+
/// scf.while (%init) : (...) -> ... {
27+
/// %cond = ...
28+
/// %res = scf.if %cond -> (...) {
29+
/// use1(%init)
30+
/// %then_val = ...
31+
/// ... // then block
32+
/// scf.yield %then_val
33+
/// } else {
34+
/// scf.yield %init
35+
/// }
36+
/// scf.condition(%cond) %res
37+
/// } do {
38+
/// ^bb0(%arg):
39+
/// use2(%arg)
40+
/// ...
41+
///
42+
/// becomes
43+
/// scf.while (%init) : (...) -> ... {
44+
/// %cond = ...
45+
/// scf.condition(%cond) %init
46+
/// } do {
47+
/// ^bb0(%arg): :
48+
/// use1(%arg)
49+
/// ... // if then block
50+
/// %then_val = ...
51+
/// use2(%then_val)
52+
/// ...
53+
struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
54+
using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
55+
56+
LogicalResult matchAndRewrite(scf::WhileOp op,
57+
PatternRewriter &rewriter) const override {
58+
// Check that the first opeation produces one result and that result must
59+
// have exactly two uses (these two uses come from the `scf.if` and
60+
// `scf.condition` operations).
61+
Operation &condOp = op.getBeforeBody()->front();
62+
if (condOp.getNumResults() != 1 || !condOp.getResult(0).hasNUses(2))
63+
return failure();
64+
65+
Value condVal = condOp.getResult(0);
66+
auto ifOp = dyn_cast<scf::IfOp>(condOp.getNextNode());
67+
if (!ifOp || ifOp.getCondition() != condVal)
68+
return failure();
69+
70+
auto term = dyn_cast<scf::ConditionOp>(ifOp->getNextNode());
71+
if (!term || term.getCondition() != condVal)
72+
return failure();
73+
74+
// Check that if results and else yield operands match the scf.condition op
75+
// arguments and while before arguments respectively.
76+
if (!llvm::equal(ifOp->getResults(), term.getArgs()) ||
77+
!llvm::equal(ifOp.elseYield()->getOperands(), op.getBeforeArguments()))
78+
return failure();
79+
80+
// Update uses and move the if op into the after region.
81+
rewriter.replaceAllUsesWith(op.getAfterArguments(),
82+
ifOp.thenYield()->getOperands());
83+
rewriter.replaceUsesWithIf(op.getBeforeArguments(), op.getAfterArguments(),
84+
[&](OpOperand &use) {
85+
return ifOp.getThenRegion().isAncestor(
86+
use.getOwner()->getParentRegion());
87+
});
88+
rewriter.modifyOpInPlace(
89+
term, [&]() { term.getArgsMutable().assign(op.getBeforeArguments()); });
90+
91+
rewriter.eraseOp(ifOp.thenYield());
92+
rewriter.inlineBlockBefore(ifOp.thenBlock(), op.getAfterBody(),
93+
op.getAfterBody()->begin());
94+
rewriter.eraseOp(ifOp);
95+
return success();
96+
}
97+
};
98+
2299
struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
23100
using OpRewritePattern::OpRewritePattern;
24101

@@ -267,5 +344,5 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
267344
}
268345

269346
void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
270-
patterns.add<UpliftWhileOp>(patterns.getContext());
347+
patterns.add<UpliftWhileOp, WhileMoveIfDown>(patterns.getContext());
271348
}

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -974,38 +974,6 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
974974

975975
// -----
976976

977-
// CHECK-LABEL: @while_move_if_down
978-
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: i32)
979-
func.func @while_move_if_down(%arg0: index, %arg1: i32) -> i32 {
980-
%0:2 = scf.while (%init0 = %arg0, %init1 = %arg1) : (index, i32) -> (index, i32) {
981-
%condition = "test.condition"() : () -> i1
982-
%res:2 = scf.if %condition -> (index, i32) {
983-
%then_val:2 = "test.use1"(%init0, %init1) : (index, i32) -> (i32, index)
984-
scf.yield %then_val#1, %then_val#0 : index, i32
985-
} else {
986-
scf.yield %init0, %init1 : index, i32
987-
}
988-
scf.condition(%condition) %res#0, %res#1 : index, i32
989-
} do {
990-
^bb0(%arg2: index, %arg3: i32):
991-
%1:2 = "test.use2"(%arg2, %arg3) : (index, i32) -> (i32, index)
992-
scf.yield %1#1, %1#0 : index, i32
993-
}
994-
return %0#1 : i32
995-
}
996-
// CHECK-NEXT: %[[WHILE_0:.*]]:2 = scf.while (%[[VAL_0:.*]] = %[[ARG0]], %[[VAL_1:.*]] = %[[ARG1]]) : (index, i32) -> (index, i32) {
997-
// CHECK-NEXT: %[[VAL_2:.*]] = "test.condition"() : () -> i1
998-
// CHECK-NEXT: scf.condition(%[[VAL_2]]) %[[VAL_0]], %[[VAL_1]] : index, i32
999-
// CHECK-NEXT: } do {
1000-
// CHECK-NEXT: ^bb0(%[[VAL_3:.*]]: index, %[[VAL_4:.*]]: i32):
1001-
// CHECK-NEXT: %[[VAL_5:.*]]:2 = "test.use1"(%[[VAL_3]], %[[VAL_4]]) : (index, i32) -> (i32, index)
1002-
// CHECK-NEXT: %[[VAL_6:.*]]:2 = "test.use2"(%[[VAL_5]]#1, %[[VAL_5]]#0) : (index, i32) -> (i32, index)
1003-
// CHECK-NEXT: scf.yield %[[VAL_6]]#1, %[[VAL_6]]#0 : index, i32
1004-
// CHECK-NEXT: }
1005-
// CHECK-NEXT: return %[[VAL_7:.*]]#1 : i32
1006-
1007-
// -----
1008-
1009977
// CHECK-LABEL: @while_cond_true
1010978
func.func @while_cond_true() -> i1 {
1011979
%0 = scf.while () : () -> i1 {

mlir/test/Dialect/SCF/uplift-while.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,34 @@ func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32)
185185
// CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
186186
// CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32
187187
// CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32
188+
189+
// -----
190+
191+
func.func @uplift_while(%low: index, %upper: index, %val : i32) -> i32 {
192+
%c1 = arith.constant 1 : index
193+
%1:2 = scf.while (%iv = %low, %iter = %val) : (index, i32) -> (index, i32) {
194+
%2 = arith.cmpi slt, %iv, %upper : index
195+
%3:2 = scf.if %2 -> (index, i32) {
196+
%4 = "test.test"(%iter) : (i32) -> i32
197+
%5 = arith.addi %iv, %c1 : index
198+
scf.yield %5, %4 : index, i32
199+
} else {
200+
scf.yield %iv, %iter : index, i32
201+
}
202+
scf.condition(%2) %3#0, %3#1 : index, i32
203+
} do {
204+
^bb0(%arg0: index, %arg1: i32):
205+
scf.yield %arg0, %arg1 : index, i32
206+
}
207+
return %1#1 : i32
208+
}
209+
210+
// CHECK-LABEL: func.func @uplift_while(
211+
// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i32) -> i32 {
212+
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : index
213+
// CHECK: %[[FOR_0:.*]] = scf.for %[[VAL_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[CONSTANT_0]] iter_args(%[[VAL_1:.*]] = %[[ARG2]]) -> (i32) {
214+
// CHECK: %[[VAL_2:.*]] = "test.test"(%[[VAL_1]]) : (i32) -> i32
215+
// CHECK: scf.yield %[[VAL_2]] : i32
216+
// CHECK: }
217+
// CHECK: return %[[FOR_0]] : i32
218+
// CHECK: }

0 commit comments

Comments
 (0)