Skip to content

Commit 8ceeba8

Browse files
authored
[MLIR][SCF] Canonicalize redundant scf.if from scf.while before region into after region (llvm#169892)
When a `scf.if` directly precedes a `scf.condition` in the before region of a `scf.while` and both share the same condition, move the if into the after region of the loop. This helps simplify the control flow to enable uplifting `scf.while` to `scf.for`.
1 parent b7721c5 commit 8ceeba8

File tree

2 files changed

+180
-1
lines changed

2 files changed

+180
-1
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
2727
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
2828
#include "mlir/Transforms/InliningUtils.h"
29+
#include "mlir/Transforms/RegionUtils.h"
2930
#include "llvm/ADT/MapVector.h"
3031
#include "llvm/ADT/STLExtras.h"
3132
#include "llvm/ADT/SmallPtrSet.h"
@@ -3687,6 +3688,133 @@ LogicalResult scf::WhileOp::verify() {
36873688
}
36883689

36893690
namespace {
3691+
/// Move a scf.if op that is directly before the scf.condition op in the while
3692+
/// before region, and whose condition matches the condition of the
3693+
/// scf.condition op, down into the while after region.
3694+
///
3695+
/// scf.while (..) : (...) -> ... {
3696+
/// %additional_used_values = ...
3697+
/// %cond = ...
3698+
/// ...
3699+
/// %res = scf.if %cond -> (...) {
3700+
/// use(%additional_used_values)
3701+
/// ... // then block
3702+
/// scf.yield %then_value
3703+
/// } else {
3704+
/// scf.yield %else_value
3705+
/// }
3706+
/// scf.condition(%cond) %res, ...
3707+
/// } do {
3708+
/// ^bb0(%res_arg, ...):
3709+
/// use(%res_arg)
3710+
/// ...
3711+
///
3712+
/// becomes
3713+
/// scf.while (..) : (...) -> ... {
3714+
/// %additional_used_values = ...
3715+
/// %cond = ...
3716+
/// ...
3717+
/// scf.condition(%cond) %else_value, ..., %additional_used_values
3718+
/// } do {
3719+
/// ^bb0(%res_arg ..., %additional_args): :
3720+
/// use(%additional_args)
3721+
/// ... // if then block
3722+
/// use(%then_value)
3723+
/// ...
3724+
struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
3725+
using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3726+
3727+
LogicalResult matchAndRewrite(scf::WhileOp op,
3728+
PatternRewriter &rewriter) const override {
3729+
auto conditionOp = op.getConditionOp();
3730+
3731+
// Only support ifOp right before the condition at the moment. Relaxing this
3732+
// would require to:
3733+
// - check that the body does not have side-effects conflicting with
3734+
// operations between the if and the condition.
3735+
// - check that results of the if operation are only used as arguments to
3736+
// the condition.
3737+
auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
3738+
3739+
// Check that the ifOp is directly before the conditionOp and that it
3740+
// matches the condition of the conditionOp. Also ensure that the ifOp has
3741+
// no else block with content, as that would complicate the transformation.
3742+
// TODO: support else blocks with content.
3743+
if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
3744+
(ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
3745+
return failure();
3746+
3747+
assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
3748+
*ifOp->user_begin() == conditionOp) &&
3749+
"ifOp has unexpected uses");
3750+
3751+
Location loc = op.getLoc();
3752+
3753+
// Replace uses of ifOp results in the conditionOp with the yielded values
3754+
// from the ifOp branches.
3755+
for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
3756+
auto it = llvm::find(ifOp->getResults(), arg);
3757+
if (it != ifOp->getResults().end()) {
3758+
size_t ifOpIdx = it.getIndex();
3759+
Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
3760+
Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
3761+
3762+
rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
3763+
rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
3764+
}
3765+
}
3766+
3767+
// Collect additional used values from before region.
3768+
SetVector<Value> additionalUsedValuesSet;
3769+
visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) {
3770+
if (&op.getBefore() == operand->get().getParentRegion())
3771+
additionalUsedValuesSet.insert(operand->get());
3772+
});
3773+
3774+
// Create new whileOp with additional used values as results.
3775+
auto additionalUsedValues = additionalUsedValuesSet.getArrayRef();
3776+
auto additionalValueTypes = llvm::map_to_vector(
3777+
additionalUsedValues, [](Value val) { return val.getType(); });
3778+
size_t additionalValueSize = additionalUsedValues.size();
3779+
SmallVector<Type> newResultTypes(op.getResultTypes());
3780+
newResultTypes.append(additionalValueTypes);
3781+
3782+
auto newWhileOp =
3783+
scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
3784+
3785+
rewriter.modifyOpInPlace(newWhileOp, [&] {
3786+
newWhileOp.getBefore().takeBody(op.getBefore());
3787+
newWhileOp.getAfter().takeBody(op.getAfter());
3788+
newWhileOp.getAfter().addArguments(
3789+
additionalValueTypes,
3790+
SmallVector<Location>(additionalValueSize, loc));
3791+
});
3792+
3793+
rewriter.modifyOpInPlace(conditionOp, [&] {
3794+
conditionOp.getArgsMutable().append(additionalUsedValues);
3795+
});
3796+
3797+
// Replace uses of additional used values inside the ifOp then region with
3798+
// the whileOp after region arguments.
3799+
rewriter.replaceUsesWithIf(
3800+
additionalUsedValues,
3801+
newWhileOp.getAfterArguments().take_back(additionalValueSize),
3802+
[&](OpOperand &use) {
3803+
return ifOp.getThenRegion().isAncestor(
3804+
use.getOwner()->getParentRegion());
3805+
});
3806+
3807+
// Inline ifOp then region into new whileOp after region.
3808+
rewriter.eraseOp(ifOp.thenYield());
3809+
rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
3810+
newWhileOp.getAfterBody()->begin());
3811+
rewriter.eraseOp(ifOp);
3812+
rewriter.replaceOp(op,
3813+
newWhileOp->getResults().drop_back(additionalValueSize));
3814+
return success();
3815+
}
3816+
};
3817+
36903818
/// Replace uses of the condition within the do block with true, since otherwise
36913819
/// the block would not be evaluated.
36923820
///
@@ -4399,7 +4527,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
43994527
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
44004528
RemoveLoopInvariantValueYielded, WhileConditionTruth,
44014529
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4402-
WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4530+
WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
4531+
context);
44034532
}
44044533

44054534
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,56 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
974974

975975
// -----
976976

977+
// CHECK-LABEL: @while_move_if_down
978+
func.func @while_move_if_down() -> i32 {
979+
%defined_outside = "test.get_some_value0" () : () -> (i32)
980+
%0 = scf.while () : () -> (i32) {
981+
%used_value = "test.get_some_value1" () : () -> (i32)
982+
%used_by_subregion = "test.get_some_value2" () : () -> (i32)
983+
%else_value = "test.get_some_value3" () : () -> (i32)
984+
%condition = "test.condition"() : () -> i1
985+
%res = scf.if %condition -> (i32) {
986+
"test.use0" (%defined_outside) : (i32) -> ()
987+
"test.use1" (%used_value) : (i32) -> ()
988+
test.alloca_scope_region {
989+
"test.use2" (%used_by_subregion) : (i32) -> ()
990+
}
991+
%then_value = "test.get_some_value4" () : () -> (i32)
992+
scf.yield %then_value : i32
993+
} else {
994+
scf.yield %else_value : i32
995+
}
996+
scf.condition(%condition) %res : i32
997+
} do {
998+
^bb0(%res_arg: i32):
999+
"test.use3" (%res_arg) : (i32) -> ()
1000+
scf.yield
1001+
}
1002+
return %0 : i32
1003+
}
1004+
// CHECK: %[[defined_outside:.*]] = "test.get_some_value0"() : () -> i32
1005+
// CHECK: %[[WHILE_RES:.*]]:3 = scf.while : () -> (i32, i32, i32) {
1006+
// CHECK: %[[used_value:.*]] = "test.get_some_value1"() : () -> i32
1007+
// CHECK: %[[used_by_subregion:.*]] = "test.get_some_value2"() : () -> i32
1008+
// CHECK: %[[else_value:.*]] = "test.get_some_value3"() : () -> i32
1009+
// CHECK: %[[condition:.*]] = "test.condition"() : () -> i1
1010+
// CHECK: scf.condition(%[[condition]]) %[[else_value]], %[[used_value]], %[[used_by_subregion]] : i32, i32, i32
1011+
// CHECK: } do {
1012+
// CHECK: ^bb0(%[[res_arg:.*]]: i32, %[[used_value_arg:.*]]: i32, %[[used_by_subregion_arg:.*]]: i32):
1013+
// CHECK: "test.use0"(%[[defined_outside]]) : (i32) -> ()
1014+
// CHECK: "test.use1"(%[[used_value_arg]]) : (i32) -> ()
1015+
// CHECK: test.alloca_scope_region {
1016+
// CHECK: "test.use2"(%[[used_by_subregion_arg]]) : (i32) -> ()
1017+
// CHECK: }
1018+
// CHECK: %[[then_value:.*]] = "test.get_some_value4"() : () -> i32
1019+
// CHECK: "test.use3"(%[[then_value]]) : (i32) -> ()
1020+
// CHECK: scf.yield
1021+
// CHECK: }
1022+
// CHECK: return %[[WHILE_RES]]#0 : i32
1023+
// CHECK: }
1024+
1025+
// -----
1026+
9771027
// CHECK-LABEL: @while_cond_true
9781028
func.func @while_cond_true() -> i1 {
9791029
%0 = scf.while () : () -> i1 {

0 commit comments

Comments
 (0)