Skip to content

Commit 1171117

Browse files
committed
[CIR] Fix multiple return statements in switch cases
Add support for multiple return statements in switch cases. Cases in switch statements don't have their own scopes but are distinct regions nonetheless. Insert a separate return block for each case and handle all of them in the cleanup code.
1 parent a34e8c3 commit 1171117

File tree

3 files changed

+160
-41
lines changed

3 files changed

+160
-41
lines changed

clang/lib/CIR/CodeGen/CIRGenFunction.cpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -242,12 +242,15 @@ void CIRGenFunction::LexicalScope::cleanup() {
242242
}
243243
};
244244

245-
if (returnBlock != nullptr) {
246-
// Write out the return block, which loads the value from `__retval` and
247-
// issues the `cir.return`.
245+
// Cleanups are done right before codegen resumes a scope. This is where
246+
// objects are destroyed. Process all return blocks.
247+
llvm::SmallVector<mlir::Block *> retBlocks;
248+
for (mlir::Block *retBlock : localScope->getRetBlocks()) {
248249
mlir::OpBuilder::InsertionGuard guard(builder);
249-
builder.setInsertionPointToEnd(returnBlock);
250-
(void)emitReturn(*returnLoc);
250+
builder.setInsertionPointToEnd(retBlock);
251+
retBlocks.push_back(retBlock);
252+
mlir::Location retLoc = localScope->getRetLoc(retBlock);
253+
emitReturn(retLoc);
251254
}
252255

253256
auto insertCleanupAndLeave = [&](mlir::Block *insPt) {
@@ -274,19 +277,21 @@ void CIRGenFunction::LexicalScope::cleanup() {
274277

275278
if (localScope->depth == 0) {
276279
// Reached the end of the function.
277-
if (returnBlock != nullptr) {
278-
if (returnBlock->getUses().empty()) {
279-
returnBlock->erase();
280+
// Special handling only for single return block case
281+
if (localScope->getRetBlocks().size() == 1) {
282+
mlir::Block *retBlock = localScope->getRetBlocks()[0];
283+
mlir::Location retLoc = localScope->getRetLoc(retBlock);
284+
if (retBlock->getUses().empty()) {
285+
retBlock->erase();
280286
} else {
281287
// Thread return block via cleanup block.
282288
if (cleanupBlock) {
283-
for (mlir::BlockOperand &blockUse : returnBlock->getUses()) {
289+
for (mlir::BlockOperand &blockUse : retBlock->getUses()) {
284290
cir::BrOp brOp = mlir::cast<cir::BrOp>(blockUse.getOwner());
285291
brOp.setSuccessor(cleanupBlock);
286292
}
287293
}
288-
289-
builder.create<cir::BrOp>(*returnLoc, returnBlock);
294+
builder.create<cir::BrOp>(retLoc, retBlock);
290295
return;
291296
}
292297
}
@@ -324,8 +329,10 @@ void CIRGenFunction::LexicalScope::cleanup() {
324329
bool entryBlock = builder.getInsertionBlock()->isEntryBlock();
325330
if (!entryBlock && curBlock->empty()) {
326331
curBlock->erase();
327-
if (returnBlock != nullptr && returnBlock->getUses().empty())
328-
returnBlock->erase();
332+
for (mlir::Block *retBlock : retBlocks) {
333+
if (retBlock->getUses().empty())
334+
retBlock->erase();
335+
}
329336
return;
330337
}
331338

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,44 +1040,69 @@ class CIRGenFunction : public CIRGenTypeCache {
10401040
// ---
10411041

10421042
private:
1043-
// `returnBlock`, `returnLoc`, and all the functions that deal with them
1044-
// will change and become more complicated when `switch` statements are
1045-
// upstreamed. `case` statements within the `switch` are in the same scope
1046-
// but have their own regions. Therefore the LexicalScope will need to
1047-
// keep track of multiple return blocks.
1048-
mlir::Block *returnBlock = nullptr;
1049-
std::optional<mlir::Location> returnLoc;
1050-
1051-
// See the comment on `getOrCreateRetBlock`.
1043+
// On switches we need one return block per region, since cases don't
1044+
// have their own scopes but are distinct regions nonetheless.
1045+
1046+
// TODO: This implementation should change once we have support for early
1047+
// exits in MLIR structured control flow (llvm-project#161575)
1048+
llvm::SmallVector<mlir::Block *> retBlocks;
1049+
llvm::DenseMap<mlir::Block *, mlir::Location> retLocs;
1050+
llvm::DenseMap<cir::CaseOp, unsigned> retBlockInCaseIndex;
1051+
std::optional<unsigned> normalRetBlockIndex;
1052+
1053+
// There's usually only one ret block per scope, but this needs to be
1054+
// get-or-create because of potentially unreachable return statements. Note
1055+
// that for those, all source locations map to the first one found.
10521056
mlir::Block *createRetBlock(CIRGenFunction &cgf, mlir::Location loc) {
1053-
assert(returnBlock == nullptr && "only one return block per scope");
1054-
// Create the cleanup block but don't hook it up just yet.
1057+
assert((isa_and_nonnull<cir::CaseOp>(
1058+
cgf.builder.getBlock()->getParentOp()) ||
1059+
retBlocks.size() == 0) &&
1060+
"only switches can hold more than one ret block");
1061+
1062+
// Create the return block but don't hook it up just yet.
10551063
mlir::OpBuilder::InsertionGuard guard(cgf.builder);
1056-
returnBlock =
1057-
cgf.builder.createBlock(cgf.builder.getBlock()->getParent());
1058-
updateRetLoc(returnBlock, loc);
1059-
return returnBlock;
1064+
auto *b = cgf.builder.createBlock(cgf.builder.getBlock()->getParent());
1065+
retBlocks.push_back(b);
1066+
updateRetLoc(b, loc);
1067+
return b;
10601068
}
10611069

10621070
cir::ReturnOp emitReturn(mlir::Location loc);
10631071
void emitImplicitReturn();
10641072

10651073
public:
1066-
mlir::Block *getRetBlock() { return returnBlock; }
1067-
mlir::Location getRetLoc(mlir::Block *b) { return *returnLoc; }
1068-
void updateRetLoc(mlir::Block *b, mlir::Location loc) { returnLoc = loc; }
1069-
1070-
// Create the return block for this scope, or return the existing one.
1071-
// This get-or-create logic is necessary to handle multiple return
1072-
// statements within the same scope, which can happen if some of them are
1073-
// dead code or if there is a `goto` into the middle of the scope.
1074+
llvm::ArrayRef<mlir::Block *> getRetBlocks() { return retBlocks; }
1075+
mlir::Location getRetLoc(mlir::Block *b) { return retLocs.at(b); }
1076+
void updateRetLoc(mlir::Block *b, mlir::Location loc) {
1077+
retLocs.insert_or_assign(b, loc);
1078+
}
1079+
10741080
mlir::Block *getOrCreateRetBlock(CIRGenFunction &cgf, mlir::Location loc) {
1075-
if (returnBlock == nullptr) {
1076-
returnBlock = createRetBlock(cgf, loc);
1077-
return returnBlock;
1081+
// Check if we're inside a case region
1082+
if (auto caseOp = mlir::dyn_cast_if_present<cir::CaseOp>(
1083+
cgf.builder.getBlock()->getParentOp())) {
1084+
auto iter = retBlockInCaseIndex.find(caseOp);
1085+
if (iter != retBlockInCaseIndex.end()) {
1086+
// Reuse existing return block
1087+
mlir::Block *ret = retBlocks[iter->second];
1088+
updateRetLoc(ret, loc);
1089+
return ret;
1090+
}
1091+
// Create new return block
1092+
mlir::Block *ret = createRetBlock(cgf, loc);
1093+
retBlockInCaseIndex[caseOp] = retBlocks.size() - 1;
1094+
return ret;
10781095
}
1079-
updateRetLoc(returnBlock, loc);
1080-
return returnBlock;
1096+
1097+
if (normalRetBlockIndex) {
1098+
mlir::Block *ret = retBlocks[*normalRetBlockIndex];
1099+
updateRetLoc(ret, loc);
1100+
return ret;
1101+
}
1102+
1103+
mlir::Block *ret = createRetBlock(cgf, loc);
1104+
normalRetBlockIndex = retBlocks.size() - 1;
1105+
return ret;
10811106
}
10821107

10831108
mlir::Block *getEntryBlock() { return entryBlock; }

clang/test/CIR/CodeGen/switch.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,3 +1183,90 @@ int nested_switch(int a) {
11831183
// OGCG: [[IFEND10]]:
11841184
// OGCG: br label %[[EPILOG]]
11851185
// OGCG: [[EPILOG]]:
1186+
1187+
int sw_return_multi_cases(int x) {
1188+
switch (x) {
1189+
case 0:
1190+
return 0;
1191+
case 1:
1192+
return 1;
1193+
case 2:
1194+
return 2;
1195+
default:
1196+
return -1;
1197+
}
1198+
}
1199+
1200+
// CIR-LABEL: cir.func{{.*}} @_Z21sw_return_multi_casesi
1201+
// CIR: cir.switch (%{{.*}} : !s32i) {
1202+
// CIR-NEXT: cir.case(equal, [#cir.int<0> : !s32i]) {
1203+
// CIR: %[[ZERO:.*]] = cir.const #cir.int<0> : !s32i
1204+
// CIR: cir.store{{.*}} %[[ZERO]], %{{.*}} : !s32i, !cir.ptr<!s32i>
1205+
// CIR: %[[RET0:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
1206+
// CIR-NEXT: cir.return %[[RET0]] : !s32i
1207+
// CIR-NEXT: }
1208+
// CIR-NEXT: cir.case(equal, [#cir.int<1> : !s32i]) {
1209+
// CIR: %[[ONE:.*]] = cir.const #cir.int<1> : !s32i
1210+
// CIR: cir.store{{.*}} %[[ONE]], %{{.*}} : !s32i, !cir.ptr<!s32i>
1211+
// CIR: %[[RET1:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
1212+
// CIR-NEXT: cir.return %[[RET1]] : !s32i
1213+
// CIR-NEXT: }
1214+
// CIR-NEXT: cir.case(equal, [#cir.int<2> : !s32i]) {
1215+
// CIR: %[[TWO:.*]] = cir.const #cir.int<2> : !s32i
1216+
// CIR: cir.store{{.*}} %[[TWO]], %{{.*}} : !s32i, !cir.ptr<!s32i>
1217+
// CIR: %[[RET2:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
1218+
// CIR-NEXT: cir.return %[[RET2]] : !s32i
1219+
// CIR-NEXT: }
1220+
// CIR-NEXT: cir.case(default, []) {
1221+
// CIR: %[[ONE:.*]] = cir.const #cir.int<1> : !s32i
1222+
// CIR: %[[NEG:.*]] = cir.unary(minus, %[[ONE]]) {{.*}} : !s32i, !s32i
1223+
// CIR: cir.store{{.*}} %[[NEG]], %{{.*}} : !s32i, !cir.ptr<!s32i>
1224+
// CIR: %[[RETDEF:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
1225+
// CIR-NEXT: cir.return %[[RETDEF]] : !s32i
1226+
// CIR-NEXT: }
1227+
// CIR-NEXT: cir.yield
1228+
1229+
// LLVM-LABEL: define{{.*}} i32 @_Z21sw_return_multi_casesi
1230+
// LLVM: switch i32 %{{.*}}, label %[[DEFAULT:.*]] [
1231+
// LLVM-DAG: i32 0, label %[[CASE0:.*]]
1232+
// LLVM-DAG: i32 1, label %[[CASE1:.*]]
1233+
// LLVM-DAG: i32 2, label %[[CASE2:.*]]
1234+
// LLVM: ]
1235+
// LLVM: [[CASE0]]:
1236+
// LLVM: store i32 0, ptr %{{.*}}, align 4
1237+
// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4
1238+
// LLVM: ret i32 %{{.*}}
1239+
// LLVM: [[CASE1]]:
1240+
// LLVM: store i32 1, ptr %{{.*}}, align 4
1241+
// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4
1242+
// LLVM: ret i32 %{{.*}}
1243+
// LLVM: [[CASE2]]:
1244+
// LLVM: store i32 2, ptr %{{.*}}, align 4
1245+
// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4
1246+
// LLVM: ret i32 %{{.*}}
1247+
// LLVM: [[DEFAULT]]:
1248+
// LLVM: store i32 -1, ptr %{{.*}}, align 4
1249+
// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4
1250+
// LLVM: ret i32 %{{.*}}
1251+
1252+
// OGCG-LABEL: define{{.*}} i32 @_Z21sw_return_multi_casesi
1253+
// OGCG: entry:
1254+
// OGCG: %[[RETVAL:.*]] = alloca i32, align 4
1255+
// OGCG: %[[X_ADDR:.*]] = alloca i32, align 4
1256+
// OGCG: %[[X_VAL:.*]] = load i32, ptr %[[X_ADDR]], align 4
1257+
// OGCG: switch i32 %[[X_VAL]], label %[[DEFAULT:.*]] [
1258+
// OGCG-DAG: i32 0, label %[[SW0:.*]]
1259+
// OGCG-DAG: i32 1, label %[[SW1:.*]]
1260+
// OGCG-DAG: i32 2, label %[[SW2:.*]]
1261+
// OGCG: ]
1262+
// OGCG: [[SW0]]:
1263+
// OGCG: br label %[[RETURN:.*]]
1264+
// OGCG: [[SW1]]:
1265+
// OGCG: br label %[[RETURN]]
1266+
// OGCG: [[SW2]]:
1267+
// OGCG: br label %[[RETURN]]
1268+
// OGCG: [[DEFAULT]]:
1269+
// OGCG: br label %[[RETURN]]
1270+
// OGCG: [[RETURN]]:
1271+
// OGCG: %[[RETVAL_LOAD:.*]] = load i32, ptr %[[RETVAL]], align 4
1272+
// OGCG: ret i32 %[[RETVAL_LOAD]]

0 commit comments

Comments
 (0)