Skip to content

Commit fa7eb12

Browse files
authored
[SwitchToIf] support multi-result scf.index_switch + add regression test (#9245)
* SwitchToIf: support multi-result scf.index_switch + add regression test * IndexSwitchToIf: fail if adaptor operand missing
1 parent 33cc4ec commit fa7eb12

File tree

2 files changed

+62
-7
lines changed

2 files changed

+62
-7
lines changed

lib/Transforms/IndexSwitchToIf.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/IR/OperationSupport.h"
1919
#include "mlir/Pass/Pass.h"
2020
#include "mlir/Transforms/DialectConversion.h"
21+
#include "llvm/ADT/STLExtras.h"
2122

2223
namespace circt {
2324
#define GEN_PASS_DEF_INDEXSWITCHTOIF
@@ -38,21 +39,24 @@ struct SwitchToIfConversion : public OpConversionPattern<scf::IndexSwitchOp> {
3839
Region &defaultRegion = switchOp.getDefaultRegion();
3940
bool hasResults = !switchOp.getResultTypes().empty();
4041

41-
Value finalResult;
42+
SmallVector<Value> finalResults;
4243
scf::IfOp prevIfOp = nullptr;
4344

4445
rewriter.setInsertionPointAfter(switchOp);
4546
auto switchCases = switchOp.getCases();
47+
Value switchOperand = adaptor.getArg();
48+
if (!switchOperand)
49+
return rewriter.notifyMatchFailure(switchOp,
50+
"missing converted switch operand");
4651
for (size_t i = 0; i < switchCases.size(); i++) {
4752
auto caseValueInt = switchCases[i];
4853
if (prevIfOp)
4954
rewriter.setInsertionPointToStart(&prevIfOp.getElseRegion().front());
5055

5156
Value caseValue =
5257
arith::ConstantIndexOp::create(rewriter, loc, caseValueInt);
53-
Value cond =
54-
arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
55-
switchOp.getOperand(), caseValue);
58+
Value cond = arith::CmpIOp::create(
59+
rewriter, loc, arith::CmpIPredicate::eq, switchOperand, caseValue);
5660

5761
auto ifOp = scf::IfOp::create(rewriter, loc, switchOp.getResultTypes(),
5862
cond, /*hasElseRegion=*/true);
@@ -70,17 +74,17 @@ struct SwitchToIfConversion : public OpConversionPattern<scf::IndexSwitchOp> {
7074

7175
if (prevIfOp && hasResults) {
7276
rewriter.setInsertionPointToEnd(&prevIfOp.getElseRegion().front());
73-
scf::YieldOp::create(rewriter, loc, ifOp.getResult(0));
77+
scf::YieldOp::create(rewriter, loc, ifOp.getResults());
7478
}
7579

7680
if (i == 0 && hasResults)
77-
finalResult = ifOp.getResult(0);
81+
llvm::append_range(finalResults, ifOp.getResults());
7882

7983
prevIfOp = ifOp;
8084
}
8185

8286
if (hasResults)
83-
rewriter.replaceOp(switchOp, finalResult);
87+
rewriter.replaceOp(switchOp, finalResults);
8488
else
8589
rewriter.eraseOp(switchOp);
8690

test/Transforms/switch-to-if.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,57 @@ module {
5656
}
5757
}
5858

59+
// Switches that yield multiple values should thread every result through the
60+
// nested if-else chain.
61+
62+
// -----
63+
64+
// CHECK-LABEL: func.func @multi_result(
65+
// CHECK-SAME: %[[ARG:.*]]: index) -> (i32, f32) {
66+
// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
67+
// CHECK: %[[CMP0:.*]] = arith.cmpi eq, %[[ARG]], %[[ZERO]] : index
68+
// CHECK: %[[IF0:[0-9]+]]:2 = scf.if %[[CMP0]] -> (i32, f32) {
69+
// CHECK: %[[C0:.*]] = arith.constant 10 : i32
70+
// CHECK: %[[F0:.*]] = arith.constant 1.000000e+00 : f32
71+
// CHECK: scf.yield %[[C0]], %[[F0]] : i32, f32
72+
// CHECK: } else {
73+
// CHECK: %[[ONE:.*]] = arith.constant 1 : index
74+
// CHECK: %[[CMP1:.*]] = arith.cmpi eq, %[[ARG]], %[[ONE]] : index
75+
// CHECK: %[[IF1:[0-9]+]]:2 = scf.if %[[CMP1]] -> (i32, f32) {
76+
// CHECK: %[[C1:.*]] = arith.constant 20 : i32
77+
// CHECK: %[[F1:.*]] = arith.constant 2.000000e+00 : f32
78+
// CHECK: scf.yield %[[C1]], %[[F1]] : i32, f32
79+
// CHECK: } else {
80+
// CHECK: %[[CDEF:.*]] = arith.constant 30 : i32
81+
// CHECK: %[[FDEF:.*]] = arith.constant 3.000000e+00 : f32
82+
// CHECK: scf.yield %[[CDEF]], %[[FDEF]] : i32, f32
83+
// CHECK: }
84+
// CHECK: scf.yield %[[IF1]]#0, %[[IF1]]#1 : i32, f32
85+
// CHECK: }
86+
// CHECK: return %[[IF0]]#0, %[[IF0]]#1 : i32, f32
87+
// CHECK: }
88+
module {
89+
func.func @multi_result(%arg0 : index) -> (i32, f32) {
90+
%0, %1 = scf.index_switch %arg0 -> i32, f32
91+
case 0 {
92+
%c0 = arith.constant 10 : i32
93+
%f0 = arith.constant 1.0 : f32
94+
scf.yield %c0, %f0 : i32, f32
95+
}
96+
case 1 {
97+
%c1 = arith.constant 20 : i32
98+
%f1 = arith.constant 2.0 : f32
99+
scf.yield %c1, %f1 : i32, f32
100+
}
101+
default {
102+
%cd = arith.constant 30 : i32
103+
%fd = arith.constant 3.0 : f32
104+
scf.yield %cd, %fd : i32, f32
105+
}
106+
return %0, %1 : i32, f32
107+
}
108+
}
109+
59110
// Switch to nested if-else when the yielded result is empty
60111

61112
// -----

0 commit comments

Comments
 (0)