Skip to content

Commit 56dddbd

Browse files
committed
Add more constraints to simplify the transformation process and move it into canonicalization.
We do not have an appropriate cost model for evaluation, and adding more constraints can ensure that it is always beneficial and does not introduce excessive compile-time overhead.
1 parent d76de86 commit 56dddbd

File tree

4 files changed

+111
-156
lines changed

4 files changed

+111
-156
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3546,6 +3546,82 @@ LogicalResult scf::WhileOp::verify() {
35463546
}
35473547

35483548
namespace {
3549+
/// Move an scf.if op that is directly before the scf.condition op in the while
3550+
/// before region, and whose condition matches the condition of the
3551+
/// scf.condition op, down into the while after region.
3552+
///
3553+
/// scf.while (%init) : (...) -> ... {
3554+
/// %cond = ...
3555+
/// %res = scf.if %cond -> (...) {
3556+
/// use1(%init)
3557+
/// %then_val = ...
3558+
/// ... // then block
3559+
/// scf.yield %then_val
3560+
/// } else {
3561+
/// scf.yield %init
3562+
/// }
3563+
/// scf.condition(%cond) %res
3564+
/// } do {
3565+
/// ^bb0(%arg):
3566+
/// use2(%arg)
3567+
/// ...
3568+
///
3569+
/// becomes
3570+
/// scf.while (%init) : (...) -> ... {
3571+
/// %cond = ...
3572+
/// scf.condition(%cond) %init
3573+
/// } do {
3574+
/// ^bb0(%arg): :
3575+
/// use1(%arg)
3576+
/// ... // if then block
3577+
/// %then_val = ...
3578+
/// use2(%then_val)
3579+
/// ...
3580+
struct WhileMoveIfDown : public OpRewritePattern<WhileOp> {
3581+
using OpRewritePattern<WhileOp>::OpRewritePattern;
3582+
3583+
LogicalResult matchAndRewrite(WhileOp op,
3584+
PatternRewriter &rewriter) const override {
3585+
// Check that the first operation in the before region is an scf.if whose
3586+
// condition matches the condition of the scf.condition op.
3587+
Operation &condOp = op.getBeforeBody()->front();
3588+
if (condOp.getNumResults() != 1 || !condOp.getResult(0).hasNUses(2))
3589+
return failure();
3590+
3591+
Value condVal = condOp.getResult(0);
3592+
auto ifOp = dyn_cast<scf::IfOp>(condOp.getNextNode());
3593+
if (condOp.getNumResults() != 1 || !ifOp || ifOp.getCondition() != condVal)
3594+
return failure();
3595+
3596+
auto term = dyn_cast<scf::ConditionOp>(ifOp->getNextNode());
3597+
if (!term || term.getCondition() != condVal)
3598+
return failure();
3599+
3600+
// Check that if results and else yield operands match the scf.condition op
3601+
// arguments and while before arguments respectively.
3602+
if (!llvm::equal(ifOp->getResults(), term.getArgs()) ||
3603+
!llvm::equal(ifOp.elseYield()->getOperands(), op.getBeforeArguments()))
3604+
return failure();
3605+
3606+
// Update uses and move the if op into the after region.
3607+
rewriter.replaceAllUsesWith(op.getAfterArguments(),
3608+
ifOp.thenYield()->getOperands());
3609+
rewriter.replaceUsesWithIf(op.getBeforeArguments(), op.getAfterArguments(),
3610+
[&](OpOperand &use) {
3611+
return ifOp.getThenRegion().isAncestor(
3612+
use.getOwner()->getParentRegion());
3613+
});
3614+
rewriter.modifyOpInPlace(
3615+
term, [&]() { term.getArgsMutable().assign(op.getBeforeArguments()); });
3616+
3617+
rewriter.eraseOp(ifOp.thenYield());
3618+
rewriter.inlineBlockBefore(ifOp.thenBlock(), op.getAfterBody(),
3619+
op.getAfterBody()->begin());
3620+
rewriter.eraseOp(ifOp);
3621+
return success();
3622+
}
3623+
};
3624+
35493625
/// Replace uses of the condition within the do block with true, since otherwise
35503626
/// the block would not be evaluated.
35513627
///
@@ -4258,7 +4334,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
42584334
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
42594335
RemoveLoopInvariantValueYielded, WhileConditionTruth,
42604336
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4261-
WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4337+
WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
4338+
context);
42624339
}
42634340

42644341
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp

Lines changed: 1 addition & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -15,132 +15,10 @@
1515
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
1616
#include "mlir/IR/Dominance.h"
1717
#include "mlir/IR/PatternMatch.h"
18-
#include "mlir/Transforms/RegionUtils.h"
1918

2019
using namespace mlir;
2120

2221
namespace {
23-
/// Move a scf.if op that is directly before the scf.condition op in the while
24-
/// before region, and whose condition matches the condition of the
25-
/// scf.condition op, down into the while after region.
26-
///
27-
/// scf.while (..) : (...) -> ... {
28-
/// %additional_used_values = ...
29-
/// %cond = ...
30-
/// ...
31-
/// %res = scf.if %cond -> (...) {
32-
/// use(%additional_used_values)
33-
/// ... // then block
34-
/// scf.yield %then_value
35-
/// } else {
36-
/// scf.yield %else_value
37-
/// }
38-
/// scf.condition(%cond) %res, ...
39-
/// } do {
40-
/// ^bb0(%res_arg, ...):
41-
/// use(%res_arg)
42-
/// ...
43-
///
44-
/// becomes
45-
/// scf.while (..) : (...) -> ... {
46-
/// %additional_used_values = ...
47-
/// %cond = ...
48-
/// ...
49-
/// scf.condition(%cond) %else_value, ..., %additional_used_values
50-
/// } do {
51-
/// ^bb0(%res_arg ..., %additional_args): :
52-
/// use(%additional_args)
53-
/// ... // if then block
54-
/// use(%then_value)
55-
/// ...
56-
struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
57-
using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
58-
59-
LogicalResult matchAndRewrite(scf::WhileOp op,
60-
PatternRewriter &rewriter) const override {
61-
auto conditionOp =
62-
cast<scf::ConditionOp>(op.getBeforeBody()->getTerminator());
63-
auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
64-
65-
// Check that the ifOp is directly before the conditionOp and that it
66-
// matches the condition of the conditionOp. Also ensure that the ifOp has
67-
// no else block with content, as that would complicate the transformation.
68-
// TODO: support else blocks with content.
69-
if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
70-
(ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
71-
return failure();
72-
73-
assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
74-
*ifOp->user_begin() == conditionOp) &&
75-
"ifOp has unexpected uses");
76-
77-
Location loc = op.getLoc();
78-
79-
// Replace uses of ifOp results in the conditionOp with the yielded values
80-
// from the ifOp branches.
81-
for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
82-
auto it = llvm::find(ifOp->getResults(), arg);
83-
if (it != ifOp->getResults().end()) {
84-
size_t ifOpIdx = it.getIndex();
85-
Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
86-
Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
87-
88-
rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
89-
rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
90-
}
91-
}
92-
93-
// Collect additional used values from before region.
94-
SetVector<Value> additionalUsedValues;
95-
visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) {
96-
if (op.getBefore().isAncestor(operand->get().getParentRegion()))
97-
additionalUsedValues.insert(operand->get());
98-
});
99-
100-
// Create new whileOp with additional used values as results.
101-
auto additionalValueTypes = llvm::map_to_vector(
102-
additionalUsedValues, [](Value val) { return val.getType(); });
103-
size_t additionalValueSize = additionalUsedValues.size();
104-
SmallVector<Type> newResultTypes(op.getResultTypes());
105-
newResultTypes.append(additionalValueTypes);
106-
107-
auto newWhileOp =
108-
scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
109-
110-
newWhileOp.getBefore().takeBody(op.getBefore());
111-
newWhileOp.getAfter().takeBody(op.getAfter());
112-
newWhileOp.getAfter().addArguments(
113-
additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
114-
115-
SmallVector<Value> conditionArgs = conditionOp.getArgs();
116-
llvm::append_range(conditionArgs, additionalUsedValues);
117-
118-
// Update conditionOp inside new whileOp before region.
119-
rewriter.setInsertionPoint(conditionOp);
120-
rewriter.replaceOpWithNewOp<scf::ConditionOp>(
121-
conditionOp, conditionOp.getCondition(), conditionArgs);
122-
123-
// Replace uses of additional used values inside the ifOp then region with
124-
// the whileOp after region arguments.
125-
rewriter.replaceUsesWithIf(
126-
additionalUsedValues.takeVector(),
127-
newWhileOp.getAfterArguments().take_back(additionalValueSize),
128-
[&](OpOperand &use) {
129-
return ifOp.getThenRegion().isAncestor(
130-
use.getOwner()->getParentRegion());
131-
});
132-
133-
// Inline ifOp then region into new whileOp after region.
134-
rewriter.eraseOp(ifOp.thenYield());
135-
rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
136-
newWhileOp.getAfterBody()->begin());
137-
rewriter.eraseOp(ifOp);
138-
rewriter.replaceOp(op,
139-
newWhileOp->getResults().drop_back(additionalValueSize));
140-
return success();
141-
}
142-
};
143-
14422
struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
14523
using OpRewritePattern::OpRewritePattern;
14624

@@ -389,6 +267,5 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
389267
}
390268

391269
void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
392-
patterns.add<WhileMoveIfDown, UpliftWhileOp>(patterns.getContext());
393-
scf::WhileOp::getCanonicalizationPatterns(patterns, patterns.getContext());
270+
patterns.add<UpliftWhileOp>(patterns.getContext());
394271
}

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,38 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
974974

975975
// -----
976976

977+
// CHECK-LABEL: @while_move_if_down
978+
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: i32)
979+
func.func @while_move_if_down(%arg0: index, %arg1: i32) -> i32 {
980+
%0:2 = scf.while (%init0 = %arg0, %init1 = %arg1) : (index, i32) -> (index, i32) {
981+
%condition = "test.condition"() : () -> i1
982+
%res:2 = scf.if %condition -> (index, i32) {
983+
%then_val:2 = "test.use1"(%init0, %init1) : (index, i32) -> (i32, index)
984+
scf.yield %then_val#1, %then_val#0 : index, i32
985+
} else {
986+
scf.yield %init0, %init1 : index, i32
987+
}
988+
scf.condition(%condition) %res#0, %res#1 : index, i32
989+
} do {
990+
^bb0(%arg2: index, %arg3: i32):
991+
%1:2 = "test.use2"(%arg2, %arg3) : (index, i32) -> (i32, index)
992+
scf.yield %1#1, %1#0 : index, i32
993+
}
994+
return %0#1 : i32
995+
}
996+
// CHECK-NEXT: %[[WHILE_0:.*]]:2 = scf.while (%[[VAL_0:.*]] = %[[ARG0]], %[[VAL_1:.*]] = %[[ARG1]]) : (index, i32) -> (index, i32) {
997+
// CHECK-NEXT: %[[VAL_2:.*]] = "test.condition"() : () -> i1
998+
// CHECK-NEXT: scf.condition(%[[VAL_2]]) %[[VAL_0]], %[[VAL_1]] : index, i32
999+
// CHECK-NEXT: } do {
1000+
// CHECK-NEXT: ^bb0(%[[VAL_3:.*]]: index, %[[VAL_4:.*]]: i32):
1001+
// CHECK-NEXT: %[[VAL_5:.*]]:2 = "test.use1"(%[[VAL_3]], %[[VAL_4]]) : (index, i32) -> (i32, index)
1002+
// CHECK-NEXT: %[[VAL_6:.*]]:2 = "test.use2"(%[[VAL_5]]#1, %[[VAL_5]]#0) : (index, i32) -> (i32, index)
1003+
// CHECK-NEXT: scf.yield %[[VAL_6]]#1, %[[VAL_6]]#0 : index, i32
1004+
// CHECK-NEXT: }
1005+
// CHECK-NEXT: return %[[VAL_7:.*]]#1 : i32
1006+
1007+
// -----
1008+
9771009
// CHECK-LABEL: @while_cond_true
9781010
func.func @while_cond_true() -> i1 {
9791011
%0 = scf.while () : () -> i1 {

mlir/test/Dialect/SCF/uplift-while.mlir

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -185,34 +185,3 @@ func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32)
185185
// CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
186186
// CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32
187187
// CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32
188-
189-
// -----
190-
191-
func.func @uplift_while(%low: index, %upper: index, %val : i32) -> i32 {
192-
%c1 = arith.constant 1 : index
193-
%1:2 = scf.while (%iv = %low, %iter = %val) : (index, i32) -> (index, i32) {
194-
%2 = arith.cmpi slt, %iv, %upper : index
195-
%3:2 = scf.if %2 -> (index, i32) {
196-
%4 = "test.test"(%iter) : (i32) -> i32
197-
%5 = arith.addi %iv, %c1 : index
198-
scf.yield %5, %4 : index, i32
199-
} else {
200-
scf.yield %iv, %iter : index, i32
201-
}
202-
scf.condition(%2) %3#0, %3#1 : index, i32
203-
} do {
204-
^bb0(%arg0: index, %arg1: i32):
205-
scf.yield %arg0, %arg1 : index, i32
206-
}
207-
return %1#1 : i32
208-
}
209-
210-
// CHECK-LABEL: func.func @uplift_while(
211-
// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i32) -> i32 {
212-
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : index
213-
// CHECK: %[[FOR_0:.*]] = scf.for %[[VAL_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[CONSTANT_0]] iter_args(%[[VAL_1:.*]] = %[[ARG2]]) -> (i32) {
214-
// CHECK: %[[VAL_2:.*]] = "test.test"(%[[VAL_1]]) : (i32) -> i32
215-
// CHECK: scf.yield %[[VAL_2]] : i32
216-
// CHECK: }
217-
// CHECK: return %[[FOR_0]] : i32
218-
// CHECK: }

0 commit comments

Comments
 (0)