Skip to content

Commit b881c34

Browse files
committed
[CIR] Fix multiple return statements in switch cases
Add support for multiple return statements in switch cases. Cases in switch statements don't have their own scopes but are distinct regions nonetheless. Insert multiple return blocks for each case and handle them in the cleanup code.
1 parent 0926265 commit b881c34

File tree

3 files changed

+160
-40
lines changed

3 files changed

+160
-40
lines changed

clang/lib/CIR/CodeGen/CIRGenFunction.cpp

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -242,12 +242,15 @@ void CIRGenFunction::LexicalScope::cleanup() {
242242
}
243243
};
244244

245-
if (returnBlock != nullptr) {
246-
// Write out the return block, which loads the value from `__retval` and
247-
// issues the `cir.return`.
245+
// Cleanup are done right before codegen resumes a scope. This is where
246+
// objects are destroyed. Process all return blocks.
247+
llvm::SmallVector<mlir::Block *> retBlocks;
248+
for (mlir::Block *retBlock : localScope->getRetBlocks()) {
248249
mlir::OpBuilder::InsertionGuard guard(builder);
249-
builder.setInsertionPointToEnd(returnBlock);
250-
(void)emitReturn(*returnLoc);
250+
builder.setInsertionPointToEnd(retBlock);
251+
retBlocks.push_back(retBlock);
252+
mlir::Location retLoc = localScope->getRetLoc(retBlock);
253+
emitReturn(retLoc);
251254
}
252255

253256
auto insertCleanupAndLeave = [&](mlir::Block *insPt) {
@@ -274,19 +277,22 @@ void CIRGenFunction::LexicalScope::cleanup() {
274277

275278
if (localScope->depth == 0) {
276279
// Reached the end of the function.
277-
if (returnBlock != nullptr) {
278-
if (returnBlock->getUses().empty()) {
279-
returnBlock->erase();
280+
// Special handling only for single return block case
281+
if (localScope->getRetBlocks().size() == 1) {
282+
mlir::Block *retBlock = localScope->getRetBlocks()[0];
283+
mlir::Location retLoc = localScope->getRetLoc(retBlock);
284+
if (retBlock->getUses().empty()) {
285+
retBlock->erase();
280286
} else {
281287
// Thread return block via cleanup block.
282288
if (cleanupBlock) {
283-
for (mlir::BlockOperand &blockUse : returnBlock->getUses()) {
289+
for (mlir::BlockOperand &blockUse : retBlock->getUses()) {
284290
cir::BrOp brOp = mlir::cast<cir::BrOp>(blockUse.getOwner());
285291
brOp.setSuccessor(cleanupBlock);
286292
}
287293
}
288294

289-
cir::BrOp::create(builder, *returnLoc, returnBlock);
295+
cir::BrOp::create(builder, retLoc, retBlock);
290296
return;
291297
}
292298
}
@@ -324,8 +330,10 @@ void CIRGenFunction::LexicalScope::cleanup() {
324330
bool entryBlock = builder.getInsertionBlock()->isEntryBlock();
325331
if (!entryBlock && curBlock->empty()) {
326332
curBlock->erase();
327-
if (returnBlock != nullptr && returnBlock->getUses().empty())
328-
returnBlock->erase();
333+
for (mlir::Block *retBlock : retBlocks) {
334+
if (retBlock->getUses().empty())
335+
retBlock->erase();
336+
}
329337
return;
330338
}
331339

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,44 +1103,69 @@ class CIRGenFunction : public CIRGenTypeCache {
11031103
// ---
11041104

11051105
private:
1106-
// `returnBlock`, `returnLoc`, and all the functions that deal with them
1107-
// will change and become more complicated when `switch` statements are
1108-
// upstreamed. `case` statements within the `switch` are in the same scope
1109-
// but have their own regions. Therefore the LexicalScope will need to
1110-
// keep track of multiple return blocks.
1111-
mlir::Block *returnBlock = nullptr;
1112-
std::optional<mlir::Location> returnLoc;
1113-
1114-
// See the comment on `getOrCreateRetBlock`.
1106+
// On switches we need one return block per region, since cases don't
1107+
// have their own scopes but are distinct regions nonetheless.
1108+
1109+
// TODO: This implementation should change once we have support for early
1110+
// exits in MLIR structured control flow (llvm-project#161575)
1111+
llvm::SmallVector<mlir::Block *> retBlocks;
1112+
llvm::DenseMap<mlir::Block *, mlir::Location> retLocs;
1113+
llvm::DenseMap<cir::CaseOp, unsigned> retBlockInCaseIndex;
1114+
std::optional<unsigned> normalRetBlockIndex;
1115+
1116+
// There's usually only one ret block per scope, but this needs to be
1117+
// get or create because of potential unreachable return statements, note
1118+
// that for those, all source location maps to the first one found.
11151119
mlir::Block *createRetBlock(CIRGenFunction &cgf, mlir::Location loc) {
1116-
assert(returnBlock == nullptr && "only one return block per scope");
1117-
// Create the cleanup block but don't hook it up just yet.
1120+
assert((isa_and_nonnull<cir::CaseOp>(
1121+
cgf.builder.getBlock()->getParentOp()) ||
1122+
retBlocks.size() == 0) &&
1123+
"only switches can hold more than one ret block");
1124+
1125+
// Create the return block but don't hook it up just yet.
11181126
mlir::OpBuilder::InsertionGuard guard(cgf.builder);
1119-
returnBlock =
1120-
cgf.builder.createBlock(cgf.builder.getBlock()->getParent());
1121-
updateRetLoc(returnBlock, loc);
1122-
return returnBlock;
1127+
auto *b = cgf.builder.createBlock(cgf.builder.getBlock()->getParent());
1128+
retBlocks.push_back(b);
1129+
updateRetLoc(b, loc);
1130+
return b;
11231131
}
11241132

11251133
cir::ReturnOp emitReturn(mlir::Location loc);
11261134
void emitImplicitReturn();
11271135

11281136
public:
1129-
mlir::Block *getRetBlock() { return returnBlock; }
1130-
mlir::Location getRetLoc(mlir::Block *b) { return *returnLoc; }
1131-
void updateRetLoc(mlir::Block *b, mlir::Location loc) { returnLoc = loc; }
1132-
1133-
// Create the return block for this scope, or return the existing one.
1134-
// This get-or-create logic is necessary to handle multiple return
1135-
// statements within the same scope, which can happen if some of them are
1136-
// dead code or if there is a `goto` into the middle of the scope.
1137+
llvm::ArrayRef<mlir::Block *> getRetBlocks() { return retBlocks; }
1138+
mlir::Location getRetLoc(mlir::Block *b) { return retLocs.at(b); }
1139+
void updateRetLoc(mlir::Block *b, mlir::Location loc) {
1140+
retLocs.insert_or_assign(b, loc);
1141+
}
1142+
11371143
mlir::Block *getOrCreateRetBlock(CIRGenFunction &cgf, mlir::Location loc) {
1138-
if (returnBlock == nullptr) {
1139-
returnBlock = createRetBlock(cgf, loc);
1140-
return returnBlock;
1144+
// Check if we're inside a case region
1145+
if (auto caseOp = mlir::dyn_cast_if_present<cir::CaseOp>(
1146+
cgf.builder.getBlock()->getParentOp())) {
1147+
auto iter = retBlockInCaseIndex.find(caseOp);
1148+
if (iter != retBlockInCaseIndex.end()) {
1149+
// Reuse existing return block
1150+
mlir::Block *ret = retBlocks[iter->second];
1151+
updateRetLoc(ret, loc);
1152+
return ret;
1153+
}
1154+
// Create new return block
1155+
mlir::Block *ret = createRetBlock(cgf, loc);
1156+
retBlockInCaseIndex[caseOp] = retBlocks.size() - 1;
1157+
return ret;
11411158
}
1142-
updateRetLoc(returnBlock, loc);
1143-
return returnBlock;
1159+
1160+
if (normalRetBlockIndex) {
1161+
mlir::Block *ret = retBlocks[*normalRetBlockIndex];
1162+
updateRetLoc(ret, loc);
1163+
return ret;
1164+
}
1165+
1166+
mlir::Block *ret = createRetBlock(cgf, loc);
1167+
normalRetBlockIndex = retBlocks.size() - 1;
1168+
return ret;
11441169
}
11451170

11461171
mlir::Block *getEntryBlock() { return entryBlock; }

clang/test/CIR/CodeGen/switch.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,3 +1183,90 @@ int nested_switch(int a) {
11831183
// OGCG: [[IFEND10]]:
11841184
// OGCG: br label %[[EPILOG]]
11851185
// OGCG: [[EPILOG]]:
1186+
1187+
int sw_return_multi_cases(int x) {
1188+
switch (x) {
1189+
case 0:
1190+
return 0;
1191+
case 1:
1192+
return 1;
1193+
case 2:
1194+
return 2;
1195+
default:
1196+
return -1;
1197+
}
1198+
}
1199+
1200+
// CIR-LABEL: cir.func{{.*}} @_Z21sw_return_multi_casesi
1201+
// CIR: cir.switch (%{{.*}} : !s32i) {
1202+
// CIR-NEXT: cir.case(equal, [#cir.int<0> : !s32i]) {
1203+
// CIR: %[[ZERO:.*]] = cir.const #cir.int<0> : !s32i
1204+
// CIR: cir.store{{.*}} %[[ZERO]], %{{.*}} : !s32i, !cir.ptr<!s32i>
1205+
// CIR: %[[RET0:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
1206+
// CIR-NEXT: cir.return %[[RET0]] : !s32i
1207+
// CIR-NEXT: }
1208+
// CIR-NEXT: cir.case(equal, [#cir.int<1> : !s32i]) {
1209+
// CIR: %[[ONE:.*]] = cir.const #cir.int<1> : !s32i
1210+
// CIR: cir.store{{.*}} %[[ONE]], %{{.*}} : !s32i, !cir.ptr<!s32i>
1211+
// CIR: %[[RET1:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
1212+
// CIR-NEXT: cir.return %[[RET1]] : !s32i
1213+
// CIR-NEXT: }
1214+
// CIR-NEXT: cir.case(equal, [#cir.int<2> : !s32i]) {
1215+
// CIR: %[[TWO:.*]] = cir.const #cir.int<2> : !s32i
1216+
// CIR: cir.store{{.*}} %[[TWO]], %{{.*}} : !s32i, !cir.ptr<!s32i>
1217+
// CIR: %[[RET2:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
1218+
// CIR-NEXT: cir.return %[[RET2]] : !s32i
1219+
// CIR-NEXT: }
1220+
// CIR-NEXT: cir.case(default, []) {
1221+
// CIR: %[[ONE:.*]] = cir.const #cir.int<1> : !s32i
1222+
// CIR: %[[NEG:.*]] = cir.unary(minus, %[[ONE]]) {{.*}} : !s32i, !s32i
1223+
// CIR: cir.store{{.*}} %[[NEG]], %{{.*}} : !s32i, !cir.ptr<!s32i>
1224+
// CIR: %[[RETDEF:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
1225+
// CIR-NEXT: cir.return %[[RETDEF]] : !s32i
1226+
// CIR-NEXT: }
1227+
// CIR-NEXT: cir.yield
1228+
1229+
// LLVM-LABEL: define{{.*}} i32 @_Z21sw_return_multi_casesi
1230+
// LLVM: switch i32 %{{.*}}, label %[[DEFAULT:.*]] [
1231+
// LLVM-DAG: i32 0, label %[[CASE0:.*]]
1232+
// LLVM-DAG: i32 1, label %[[CASE1:.*]]
1233+
// LLVM-DAG: i32 2, label %[[CASE2:.*]]
1234+
// LLVM: ]
1235+
// LLVM: [[CASE0]]:
1236+
// LLVM: store i32 0, ptr %{{.*}}, align 4
1237+
// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4
1238+
// LLVM: ret i32 %{{.*}}
1239+
// LLVM: [[CASE1]]:
1240+
// LLVM: store i32 1, ptr %{{.*}}, align 4
1241+
// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4
1242+
// LLVM: ret i32 %{{.*}}
1243+
// LLVM: [[CASE2]]:
1244+
// LLVM: store i32 2, ptr %{{.*}}, align 4
1245+
// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4
1246+
// LLVM: ret i32 %{{.*}}
1247+
// LLVM: [[DEFAULT]]:
1248+
// LLVM: store i32 -1, ptr %{{.*}}, align 4
1249+
// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4
1250+
// LLVM: ret i32 %{{.*}}
1251+
1252+
// OGCG-LABEL: define{{.*}} i32 @_Z21sw_return_multi_casesi
1253+
// OGCG: entry:
1254+
// OGCG: %[[RETVAL:.*]] = alloca i32, align 4
1255+
// OGCG: %[[X_ADDR:.*]] = alloca i32, align 4
1256+
// OGCG: %[[X_VAL:.*]] = load i32, ptr %[[X_ADDR]], align 4
1257+
// OGCG: switch i32 %[[X_VAL]], label %[[DEFAULT:.*]] [
1258+
// OGCG-DAG: i32 0, label %[[SW0:.*]]
1259+
// OGCG-DAG: i32 1, label %[[SW1:.*]]
1260+
// OGCG-DAG: i32 2, label %[[SW2:.*]]
1261+
// OGCG: ]
1262+
// OGCG: [[SW0]]:
1263+
// OGCG: br label %[[RETURN:.*]]
1264+
// OGCG: [[SW1]]:
1265+
// OGCG: br label %[[RETURN]]
1266+
// OGCG: [[SW2]]:
1267+
// OGCG: br label %[[RETURN]]
1268+
// OGCG: [[DEFAULT]]:
1269+
// OGCG: br label %[[RETURN]]
1270+
// OGCG: [[RETURN]]:
1271+
// OGCG: %[[RETVAL_LOAD:.*]] = load i32, ptr %[[RETVAL]], align 4
1272+
// OGCG: ret i32 %[[RETVAL_LOAD]]

0 commit comments

Comments
 (0)