Skip to content

Commit ca27260

Browse files
committed
[MLIR] Add SCF.if Condition Canonicalizations
Add two canoncalizations for scf.if. 1) A canonicalization that allows users of a condition within an if to assume the condition is true if in the true region, etc. 2) A canonicalization that removes yielded statements that are equivalent to the condition or its negation Differential Revision: https://reviews.llvm.org/D101012
1 parent 7aa3cad commit ca27260

File tree

3 files changed

+283
-8
lines changed

3 files changed

+283
-8
lines changed

mlir/lib/Dialect/SCF/SCF.cpp

Lines changed: 162 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,12 +1106,172 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
11061106
return success();
11071107
}
11081108
};
1109+
1110+
// Allow the true region of an if to assume the condition is true
1111+
// and vice versa. For example:
1112+
//
1113+
// scf.if %cmp {
1114+
// print(%cmp)
1115+
// }
1116+
//
1117+
// becomes
1118+
//
1119+
// scf.if %cmp {
1120+
// print(true)
1121+
// }
1122+
//
1123+
struct ConditionPropagation : public OpRewritePattern<IfOp> {
1124+
using OpRewritePattern<IfOp>::OpRewritePattern;
1125+
1126+
LogicalResult matchAndRewrite(IfOp op,
1127+
PatternRewriter &rewriter) const override {
1128+
// Early exit if the condition is constant since replacing a constant
1129+
// in the body with another constant isn't a simplification.
1130+
if (op.condition().getDefiningOp<ConstantOp>())
1131+
return failure();
1132+
1133+
bool changed = false;
1134+
mlir::Type i1Ty = rewriter.getI1Type();
1135+
1136+
// These variables serve to prevent creating duplicate constants
1137+
// and hold constant true or false values.
1138+
Value constantTrue = nullptr;
1139+
Value constantFalse = nullptr;
1140+
1141+
for (OpOperand &use :
1142+
llvm::make_early_inc_range(op.condition().getUses())) {
1143+
if (op.thenRegion().isAncestor(use.getOwner()->getParentRegion())) {
1144+
changed = true;
1145+
1146+
if (!constantTrue)
1147+
constantTrue = rewriter.create<mlir::ConstantOp>(
1148+
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
1149+
1150+
rewriter.updateRootInPlace(use.getOwner(),
1151+
[&]() { use.set(constantTrue); });
1152+
} else if (op.elseRegion().isAncestor(
1153+
use.getOwner()->getParentRegion())) {
1154+
changed = true;
1155+
1156+
if (!constantFalse)
1157+
constantFalse = rewriter.create<mlir::ConstantOp>(
1158+
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
1159+
1160+
rewriter.updateRootInPlace(use.getOwner(),
1161+
[&]() { use.set(constantFalse); });
1162+
}
1163+
}
1164+
1165+
return success(changed);
1166+
}
1167+
};
1168+
1169+
/// Remove any statements from an if that are equivalent to the condition
1170+
/// or its negation. For example:
1171+
///
1172+
/// %res:2 = scf.if %cmp {
1173+
/// yield something(), true
1174+
/// } else {
1175+
/// yield something2(), false
1176+
/// }
1177+
/// print(%res#1)
1178+
///
1179+
/// becomes
1180+
/// %res = scf.if %cmp {
1181+
/// yield something()
1182+
/// } else {
1183+
/// yield something2()
1184+
/// }
1185+
/// print(%cmp)
1186+
///
1187+
/// Additionally if both branches yield the same value, replace all uses
1188+
/// of the result with the yielded value
1189+
///
1190+
/// %res:2 = scf.if %cmp {
1191+
/// yield something(), %arg1
1192+
/// } else {
1193+
/// yield something2(), %arg1
1194+
/// }
1195+
/// print(%res#1)
1196+
///
1197+
/// becomes
1198+
/// %res = scf.if %cmp {
1199+
/// yield something()
1200+
/// } else {
1201+
/// yield something2()
1202+
/// }
1203+
// print(%arg1)
1204+
struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
1205+
using OpRewritePattern<IfOp>::OpRewritePattern;
1206+
1207+
LogicalResult matchAndRewrite(IfOp op,
1208+
PatternRewriter &rewriter) const override {
1209+
// Early exit if there are no results that could be replaced.
1210+
if (op.getNumResults() == 0)
1211+
return failure();
1212+
1213+
auto trueYield = cast<scf::YieldOp>(op.thenRegion().back().getTerminator());
1214+
auto falseYield =
1215+
cast<scf::YieldOp>(op.elseRegion().back().getTerminator());
1216+
1217+
rewriter.setInsertionPoint(op->getBlock(),
1218+
op.getOperation()->getIterator());
1219+
bool changed = false;
1220+
Type i1Ty = rewriter.getI1Type();
1221+
for (auto tup :
1222+
llvm::zip(trueYield.results(), falseYield.results(), op.results())) {
1223+
Value trueResult, falseResult, opResult;
1224+
std::tie(trueResult, falseResult, opResult) = tup;
1225+
1226+
if (trueResult == falseResult) {
1227+
if (!opResult.use_empty()) {
1228+
opResult.replaceAllUsesWith(trueResult);
1229+
changed = true;
1230+
}
1231+
continue;
1232+
}
1233+
1234+
auto trueYield = trueResult.getDefiningOp<ConstantOp>();
1235+
if (!trueYield)
1236+
continue;
1237+
1238+
if (!trueYield.getType().isInteger(1))
1239+
continue;
1240+
1241+
auto falseYield = falseResult.getDefiningOp<ConstantOp>();
1242+
if (!falseYield)
1243+
continue;
1244+
1245+
bool trueVal = trueYield.getValue().cast<BoolAttr>().getValue();
1246+
bool falseVal = falseYield.getValue().cast<BoolAttr>().getValue();
1247+
if (!trueVal && falseVal) {
1248+
if (!opResult.use_empty()) {
1249+
Value notCond = rewriter.create<XOrOp>(
1250+
op.getLoc(), op.condition(),
1251+
rewriter.create<mlir::ConstantOp>(
1252+
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)));
1253+
opResult.replaceAllUsesWith(notCond);
1254+
changed = true;
1255+
}
1256+
}
1257+
if (trueVal && !falseVal) {
1258+
if (!opResult.use_empty()) {
1259+
opResult.replaceAllUsesWith(op.condition());
1260+
changed = true;
1261+
}
1262+
}
1263+
}
1264+
return success(changed);
1265+
}
1266+
};
1267+
11091268
} // namespace
11101269

11111270
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
11121271
MLIRContext *context) {
1113-
results.add<RemoveUnusedResults, RemoveStaticCondition,
1114-
ConvertTrivialIfToSelect>(context);
1272+
results
1273+
.add<RemoveUnusedResults, RemoveStaticCondition, ConvertTrivialIfToSelect,
1274+
ConditionPropagation, ReplaceIfYieldWithConditionOrValue>(context);
11151275
}
11161276

11171277
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 118 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,22 +103,25 @@ func private @side_effect()
103103
func @one_unused(%cond: i1) -> (index) {
104104
%c0 = constant 0 : index
105105
%c1 = constant 1 : index
106+
%c2 = constant 2 : index
107+
%c3 = constant 3 : index
106108
%0, %1 = scf.if %cond -> (index, index) {
107109
call @side_effect() : () -> ()
108110
scf.yield %c0, %c1 : index, index
109111
} else {
110-
scf.yield %c0, %c1 : index, index
112+
scf.yield %c2, %c3 : index, index
111113
}
112114
return %1 : index
113115
}
114116

115117
// CHECK-LABEL: func @one_unused
116118
// CHECK: [[C0:%.*]] = constant 1 : index
119+
// CHECK: [[C3:%.*]] = constant 3 : index
117120
// CHECK: [[V0:%.*]] = scf.if %{{.*}} -> (index) {
118121
// CHECK: call @side_effect() : () -> ()
119122
// CHECK: scf.yield [[C0]] : index
120123
// CHECK: } else
121-
// CHECK: scf.yield [[C0]] : index
124+
// CHECK: scf.yield [[C3]] : index
122125
// CHECK: }
123126
// CHECK: return [[V0]] : index
124127

@@ -128,12 +131,14 @@ func private @side_effect()
128131
func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
129132
%c0 = constant 0 : index
130133
%c1 = constant 1 : index
134+
%c2 = constant 2 : index
135+
%c3 = constant 3 : index
131136
%0, %1 = scf.if %cond1 -> (index, index) {
132137
%2, %3 = scf.if %cond2 -> (index, index) {
133138
call @side_effect() : () -> ()
134139
scf.yield %c0, %c1 : index, index
135140
} else {
136-
scf.yield %c0, %c1 : index, index
141+
scf.yield %c2, %c3 : index, index
137142
}
138143
scf.yield %2, %3 : index, index
139144
} else {
@@ -144,12 +149,13 @@ func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
144149

145150
// CHECK-LABEL: func @nested_unused
146151
// CHECK: [[C0:%.*]] = constant 1 : index
152+
// CHECK: [[C3:%.*]] = constant 3 : index
147153
// CHECK: [[V0:%.*]] = scf.if {{.*}} -> (index) {
148154
// CHECK: [[V1:%.*]] = scf.if {{.*}} -> (index) {
149155
// CHECK: call @side_effect() : () -> ()
150156
// CHECK: scf.yield [[C0]] : index
151157
// CHECK: } else
152-
// CHECK: scf.yield [[C0]] : index
158+
// CHECK: scf.yield [[C3]] : index
153159
// CHECK: }
154160
// CHECK: scf.yield [[V1]] : index
155161
// CHECK: } else
@@ -610,3 +616,111 @@ func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) ->
610616
%res = subtensor_insert %2 into %t1[0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
611617
return %res : tensor<1024x1024xf32>
612618
}
619+
620+
621+
622+
// CHECK-LABEL: @cond_prop
623+
func @cond_prop(%arg0 : i1) -> index {
624+
%c1 = constant 1 : index
625+
%c2 = constant 2 : index
626+
%c3 = constant 3 : index
627+
%c4 = constant 4 : index
628+
%res = scf.if %arg0 -> index {
629+
%res1 = scf.if %arg0 -> index {
630+
%v1 = "test.get_some_value"() : () -> i32
631+
scf.yield %c1 : index
632+
} else {
633+
%v2 = "test.get_some_value"() : () -> i32
634+
scf.yield %c2 : index
635+
}
636+
scf.yield %res1 : index
637+
} else {
638+
%res2 = scf.if %arg0 -> index {
639+
%v3 = "test.get_some_value"() : () -> i32
640+
scf.yield %c3 : index
641+
} else {
642+
%v4 = "test.get_some_value"() : () -> i32
643+
scf.yield %c4 : index
644+
}
645+
scf.yield %res2 : index
646+
}
647+
return %res : index
648+
}
649+
// CHECK-DAG: %[[c1:.+]] = constant 1 : index
650+
// CHECK-DAG: %[[c4:.+]] = constant 4 : index
651+
// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (index) {
652+
// CHECK-NEXT: %{{.+}} = "test.get_some_value"() : () -> i32
653+
// CHECK-NEXT: scf.yield %[[c1]] : index
654+
// CHECK-NEXT: } else {
655+
// CHECK-NEXT: %{{.+}} = "test.get_some_value"() : () -> i32
656+
// CHECK-NEXT: scf.yield %[[c4]] : index
657+
// CHECK-NEXT: }
658+
// CHECK-NEXT: return %[[if]] : index
659+
// CHECK-NEXT:}
660+
661+
// CHECK-LABEL: @replace_if_with_cond1
662+
func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
663+
%true = constant true
664+
%false = constant false
665+
%res:2 = scf.if %arg0 -> (i32, i1) {
666+
%v = "test.get_some_value"() : () -> i32
667+
scf.yield %v, %true : i32, i1
668+
} else {
669+
%v2 = "test.get_some_value"() : () -> i32
670+
scf.yield %v2, %false : i32, i1
671+
}
672+
return %res#0, %res#1 : i32, i1
673+
}
674+
// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (i32) {
675+
// CHECK-NEXT: %[[sv1:.+]] = "test.get_some_value"() : () -> i32
676+
// CHECK-NEXT: scf.yield %[[sv1]] : i32
677+
// CHECK-NEXT: } else {
678+
// CHECK-NEXT: %[[sv2:.+]] = "test.get_some_value"() : () -> i32
679+
// CHECK-NEXT: scf.yield %[[sv2]] : i32
680+
// CHECK-NEXT: }
681+
// CHECK-NEXT: return %[[if]], %arg0 : i32, i1
682+
683+
// CHECK-LABEL: @replace_if_with_cond2
684+
func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
685+
%true = constant true
686+
%false = constant false
687+
%res:2 = scf.if %arg0 -> (i32, i1) {
688+
%v = "test.get_some_value"() : () -> i32
689+
scf.yield %v, %false : i32, i1
690+
} else {
691+
%v2 = "test.get_some_value"() : () -> i32
692+
scf.yield %v2, %true : i32, i1
693+
}
694+
return %res#0, %res#1 : i32, i1
695+
}
696+
// CHECK-NEXT: %true = constant true
697+
// CHECK-NEXT: %[[toret:.+]] = xor %arg0, %true : i1
698+
// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (i32) {
699+
// CHECK-NEXT: %[[sv1:.+]] = "test.get_some_value"() : () -> i32
700+
// CHECK-NEXT: scf.yield %[[sv1]] : i32
701+
// CHECK-NEXT: } else {
702+
// CHECK-NEXT: %[[sv2:.+]] = "test.get_some_value"() : () -> i32
703+
// CHECK-NEXT: scf.yield %[[sv2]] : i32
704+
// CHECK-NEXT: }
705+
// CHECK-NEXT: return %[[if]], %[[toret]] : i32, i1
706+
707+
708+
// CHECK-LABEL: @replace_if_with_cond3
709+
func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
710+
%res:2 = scf.if %arg0 -> (i32, i64) {
711+
%v = "test.get_some_value"() : () -> i32
712+
scf.yield %v, %arg2 : i32, i64
713+
} else {
714+
%v2 = "test.get_some_value"() : () -> i32
715+
scf.yield %v2, %arg2 : i32, i64
716+
}
717+
return %res#0, %res#1 : i32, i64
718+
}
719+
// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (i32) {
720+
// CHECK-NEXT: %[[sv1:.+]] = "test.get_some_value"() : () -> i32
721+
// CHECK-NEXT: scf.yield %[[sv1]] : i32
722+
// CHECK-NEXT: } else {
723+
// CHECK-NEXT: %[[sv2:.+]] = "test.get_some_value"() : () -> i32
724+
// CHECK-NEXT: scf.yield %[[sv2]] : i32
725+
// CHECK-NEXT: }
726+
// CHECK-NEXT: return %[[if]], %arg1 : i32, i64

mlir/test/Transforms/canonicalize.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,11 +1198,12 @@ func @clone_loop_alloc(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<2
11981198
// -----
11991199

12001200
// CHECK-LABEL: func @clone_nested_region
1201-
func @clone_nested_region(%arg0: index, %arg1: index) -> memref<?x?xf32> {
1201+
func @clone_nested_region(%arg0: index, %arg1: index, %arg2: index) -> memref<?x?xf32> {
1202+
%cmp = cmpi eq, %arg0, %arg1 : index
12021203
%0 = cmpi eq, %arg0, %arg1 : index
12031204
%1 = memref.alloc(%arg0, %arg0) : memref<?x?xf32>
12041205
%2 = scf.if %0 -> (memref<?x?xf32>) {
1205-
%3 = scf.if %0 -> (memref<?x?xf32>) {
1206+
%3 = scf.if %cmp -> (memref<?x?xf32>) {
12061207
%9 = memref.clone %1 : memref<?x?xf32> to memref<?x?xf32>
12071208
scf.yield %9 : memref<?x?xf32>
12081209
} else {

0 commit comments

Comments
 (0)