diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp index 25a46df406df4..23440689c14a9 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp @@ -242,12 +242,15 @@ void CIRGenFunction::LexicalScope::cleanup() { } }; - if (returnBlock != nullptr) { - // Write out the return block, which loads the value from `__retval` and - // issues the `cir.return`. + // Cleanup are done right before codegen resumes a scope. This is where + // objects are destroyed. Process all return blocks. + llvm::SmallVector retBlocks; + for (mlir::Block *retBlock : localScope->getRetBlocks()) { mlir::OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToEnd(returnBlock); - (void)emitReturn(*returnLoc); + builder.setInsertionPointToEnd(retBlock); + retBlocks.push_back(retBlock); + mlir::Location retLoc = localScope->getRetLoc(retBlock); + emitReturn(retLoc); } auto insertCleanupAndLeave = [&](mlir::Block *insPt) { @@ -274,19 +277,21 @@ void CIRGenFunction::LexicalScope::cleanup() { if (localScope->depth == 0) { // Reached the end of the function. - if (returnBlock != nullptr) { - if (returnBlock->getUses().empty()) { - returnBlock->erase(); + // Special handling only for single return block case + if (localScope->getRetBlocks().size() == 1) { + mlir::Block *retBlock = localScope->getRetBlocks()[0]; + mlir::Location retLoc = localScope->getRetLoc(retBlock); + if (retBlock->getUses().empty()) { + retBlock->erase(); } else { // Thread return block via cleanup block. 
if (cleanupBlock) { - for (mlir::BlockOperand &blockUse : returnBlock->getUses()) { + for (mlir::BlockOperand &blockUse : retBlock->getUses()) { cir::BrOp brOp = mlir::cast(blockUse.getOwner()); brOp.setSuccessor(cleanupBlock); } } - - builder.create(*returnLoc, returnBlock); + builder.create(retLoc, retBlock); return; } } @@ -324,8 +329,10 @@ void CIRGenFunction::LexicalScope::cleanup() { bool entryBlock = builder.getInsertionBlock()->isEntryBlock(); if (!entryBlock && curBlock->empty()) { curBlock->erase(); - if (returnBlock != nullptr && returnBlock->getUses().empty()) - returnBlock->erase(); + for (mlir::Block *retBlock : retBlocks) { + if (retBlock->getUses().empty()) + retBlock->erase(); + } return; } diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h index 5a71126c8dc07..dfa7917be6b57 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.h +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h @@ -1040,44 +1040,69 @@ class CIRGenFunction : public CIRGenTypeCache { // --- private: - // `returnBlock`, `returnLoc`, and all the functions that deal with them - // will change and become more complicated when `switch` statements are - // upstreamed. `case` statements within the `switch` are in the same scope - // but have their own regions. Therefore the LexicalScope will need to - // keep track of multiple return blocks. - mlir::Block *returnBlock = nullptr; - std::optional returnLoc; - - // See the comment on `getOrCreateRetBlock`. + // On switches we need one return block per region, since cases don't + // have their own scopes but are distinct regions nonetheless. 
+ + // TODO: This implementation should change once we have support for early + // exits in MLIR structured control flow (llvm-project#161575) + llvm::SmallVector retBlocks; + llvm::DenseMap retLocs; + llvm::DenseMap retBlockInCaseIndex; + std::optional normalRetBlockIndex; + + // Creates a new return block for this scope. Code that may see several + // returns goes through `getOrCreateRetBlock`, which reuses the block and + // overwrites its recorded location with the most recent `loc`. mlir::Block *createRetBlock(CIRGenFunction &cgf, mlir::Location loc) { - assert(returnBlock == nullptr && "only one return block per scope"); - // Create the cleanup block but don't hook it up just yet. + assert((isa_and_nonnull( + cgf.builder.getBlock()->getParentOp()) || + retBlocks.size() == 0) && + "only switches can hold more than one ret block"); + + // Create the return block but don't hook it up just yet. mlir::OpBuilder::InsertionGuard guard(cgf.builder); - returnBlock = - cgf.builder.createBlock(cgf.builder.getBlock()->getParent()); - updateRetLoc(returnBlock, loc); - return returnBlock; + auto *b = cgf.builder.createBlock(cgf.builder.getBlock()->getParent()); + retBlocks.push_back(b); + updateRetLoc(b, loc); + return b; } cir::ReturnOp emitReturn(mlir::Location loc); void emitImplicitReturn(); public: - mlir::Block *getRetBlock() { return returnBlock; } - mlir::Location getRetLoc(mlir::Block *b) { return *returnLoc; } - void updateRetLoc(mlir::Block *b, mlir::Location loc) { returnLoc = loc; } - - // Create the return block for this scope, or return the existing one. - // This get-or-create logic is necessary to handle multiple return - // statements within the same scope, which can happen if some of them are - // dead code or if there is a `goto` into the middle of the scope. 
+ llvm::ArrayRef getRetBlocks() { return retBlocks; } + mlir::Location getRetLoc(mlir::Block *b) { return retLocs.at(b); } + void updateRetLoc(mlir::Block *b, mlir::Location loc) { + retLocs.insert_or_assign(b, loc); + } + mlir::Block *getOrCreateRetBlock(CIRGenFunction &cgf, mlir::Location loc) { - if (returnBlock == nullptr) { - returnBlock = createRetBlock(cgf, loc); - return returnBlock; + // Inside a case region: each case op gets a return block of its own. + if (auto caseOp = mlir::dyn_cast_if_present( + cgf.builder.getBlock()->getParentOp())) { + auto iter = retBlockInCaseIndex.find(caseOp); + if (iter != retBlockInCaseIndex.end()) { + // This case already has a return block: reuse it and refresh its loc. + mlir::Block *ret = retBlocks[iter->second]; + updateRetLoc(ret, loc); + return ret; + } + // First return seen in this case: create its block, remember the index. + mlir::Block *ret = createRetBlock(cgf, loc); + retBlockInCaseIndex[caseOp] = retBlocks.size() - 1; + return ret; } - updateRetLoc(returnBlock, loc); - return returnBlock; + + if (normalRetBlockIndex) { + mlir::Block *ret = retBlocks[*normalRetBlockIndex]; + updateRetLoc(ret, loc); + return ret; + } + + mlir::Block *ret = createRetBlock(cgf, loc); + normalRetBlockIndex = retBlocks.size() - 1; + return ret; } mlir::Block *getEntryBlock() { return entryBlock; } diff --git a/clang/test/CIR/CodeGen/switch.cpp b/clang/test/CIR/CodeGen/switch.cpp index e13aa8f4f4953..3824be0d08c2f 100644 --- a/clang/test/CIR/CodeGen/switch.cpp +++ b/clang/test/CIR/CodeGen/switch.cpp @@ -1183,3 +1183,90 @@ int nested_switch(int a) { // OGCG: [[IFEND10]]: // OGCG: br label %[[EPILOG]] // OGCG: [[EPILOG]]: + +int sw_return_multi_cases(int x) { + switch (x) { + case 0: + return 0; + case 1: + return 1; + case 2: + return 2; + default: + return -1; + } +} + +// CIR-LABEL: cir.func{{.*}} @_Z21sw_return_multi_casesi +// CIR: cir.switch (%{{.*}} : !s32i) { +// CIR-NEXT: cir.case(equal, [#cir.int<0> : !s32i]) { +// CIR: %[[ZERO:.*]] = cir.const #cir.int<0> : !s32i +// CIR: cir.store{{.*}} %[[ZERO]], %{{.*}} : !s32i, !cir.ptr +// CIR: 
%[[RET0:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr, !s32i +// CIR-NEXT: cir.return %[[RET0]] : !s32i +// CIR-NEXT: } +// CIR-NEXT: cir.case(equal, [#cir.int<1> : !s32i]) { +// CIR: %[[ONE:.*]] = cir.const #cir.int<1> : !s32i +// CIR: cir.store{{.*}} %[[ONE]], %{{.*}} : !s32i, !cir.ptr +// CIR: %[[RET1:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr, !s32i +// CIR-NEXT: cir.return %[[RET1]] : !s32i +// CIR-NEXT: } +// CIR-NEXT: cir.case(equal, [#cir.int<2> : !s32i]) { +// CIR: %[[TWO:.*]] = cir.const #cir.int<2> : !s32i +// CIR: cir.store{{.*}} %[[TWO]], %{{.*}} : !s32i, !cir.ptr +// CIR: %[[RET2:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr, !s32i +// CIR-NEXT: cir.return %[[RET2]] : !s32i +// CIR-NEXT: } +// CIR-NEXT: cir.case(default, []) { +// CIR: %[[ONE:.*]] = cir.const #cir.int<1> : !s32i +// CIR: %[[NEG:.*]] = cir.unary(minus, %[[ONE]]) {{.*}} : !s32i, !s32i +// CIR: cir.store{{.*}} %[[NEG]], %{{.*}} : !s32i, !cir.ptr +// CIR: %[[RETDEF:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr, !s32i +// CIR-NEXT: cir.return %[[RETDEF]] : !s32i +// CIR-NEXT: } +// CIR-NEXT: cir.yield + +// LLVM-LABEL: define{{.*}} i32 @_Z21sw_return_multi_casesi +// LLVM: switch i32 %{{.*}}, label %[[DEFAULT:.*]] [ +// LLVM-DAG: i32 0, label %[[CASE0:.*]] +// LLVM-DAG: i32 1, label %[[CASE1:.*]] +// LLVM-DAG: i32 2, label %[[CASE2:.*]] +// LLVM: ] +// LLVM: [[CASE0]]: +// LLVM: store i32 0, ptr %{{.*}}, align 4 +// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4 +// LLVM: ret i32 %{{.*}} +// LLVM: [[CASE1]]: +// LLVM: store i32 1, ptr %{{.*}}, align 4 +// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4 +// LLVM: ret i32 %{{.*}} +// LLVM: [[CASE2]]: +// LLVM: store i32 2, ptr %{{.*}}, align 4 +// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4 +// LLVM: ret i32 %{{.*}} +// LLVM: [[DEFAULT]]: +// LLVM: store i32 -1, ptr %{{.*}}, align 4 +// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4 +// LLVM: ret i32 %{{.*}} + +// OGCG-LABEL: define{{.*}} i32 @_Z21sw_return_multi_casesi +// OGCG: entry: +// OGCG: 
%[[RETVAL:.*]] = alloca i32, align 4 +// OGCG: %[[X_ADDR:.*]] = alloca i32, align 4 +// OGCG: %[[X_VAL:.*]] = load i32, ptr %[[X_ADDR]], align 4 +// OGCG: switch i32 %[[X_VAL]], label %[[DEFAULT:.*]] [ +// OGCG-DAG: i32 0, label %[[SW0:.*]] +// OGCG-DAG: i32 1, label %[[SW1:.*]] +// OGCG-DAG: i32 2, label %[[SW2:.*]] +// OGCG: ] +// OGCG: [[SW0]]: +// OGCG: br label %[[RETURN:.*]] +// OGCG: [[SW1]]: +// OGCG: br label %[[RETURN]] +// OGCG: [[SW2]]: +// OGCG: br label %[[RETURN]] +// OGCG: [[DEFAULT]]: +// OGCG: br label %[[RETURN]] +// OGCG: [[RETURN]]: +// OGCG: %[[RETVAL_LOAD:.*]] = load i32, ptr %[[RETVAL]], align 4 +// OGCG: ret i32 %[[RETVAL_LOAD]]