Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions clang/lib/CIR/CodeGen/CIRGenFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,15 @@ void CIRGenFunction::LexicalScope::cleanup() {
}
};

if (returnBlock != nullptr) {
// Write out the return block, which loads the value from `__retval` and
// issues the `cir.return`.
// Cleanup are done right before codegen resumes a scope. This is where
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Cleanup are done right before codegen resumes a scope. This is where
// Cleanups are done right before codegen resumes a scope. This is where

// objects are destroyed. Process all return blocks.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In spite of what the comment says, I think this still has problems if we are returning through a cleanup. In the upstream code, I haven't yet committed the patch that handles return through cleanups (#163849), but I don't think that will work with multiple cleanups, as this case fails in the incubator:

https://godbolt.org/z/dTjjcG38d

That's not necessarily a problem with this PR, just a bigger problem that we need to solve with regard to cleanups.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

llvm::SmallVector<mlir::Block *> retBlocks;
for (mlir::Block *retBlock : localScope->getRetBlocks()) {
mlir::OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToEnd(returnBlock);
(void)emitReturn(*returnLoc);
builder.setInsertionPointToEnd(retBlock);
retBlocks.push_back(retBlock);
mlir::Location retLoc = localScope->getRetLoc(retBlock);
emitReturn(retLoc);
}

auto insertCleanupAndLeave = [&](mlir::Block *insPt) {
Expand All @@ -274,19 +277,21 @@ void CIRGenFunction::LexicalScope::cleanup() {

if (localScope->depth == 0) {
// Reached the end of the function.
if (returnBlock != nullptr) {
if (returnBlock->getUses().empty()) {
returnBlock->erase();
// Special handling only for single return block case
if (localScope->getRetBlocks().size() == 1) {
mlir::Block *retBlock = localScope->getRetBlocks()[0];
mlir::Location retLoc = localScope->getRetLoc(retBlock);
if (retBlock->getUses().empty()) {
retBlock->erase();
} else {
// Thread return block via cleanup block.
if (cleanupBlock) {
for (mlir::BlockOperand &blockUse : returnBlock->getUses()) {
for (mlir::BlockOperand &blockUse : retBlock->getUses()) {
cir::BrOp brOp = mlir::cast<cir::BrOp>(blockUse.getOwner());
brOp.setSuccessor(cleanupBlock);
}
}

builder.create<cir::BrOp>(*returnLoc, returnBlock);
builder.create<cir::BrOp>(retLoc, retBlock);
return;
}
}
Expand Down Expand Up @@ -324,8 +329,10 @@ void CIRGenFunction::LexicalScope::cleanup() {
bool entryBlock = builder.getInsertionBlock()->isEntryBlock();
if (!entryBlock && curBlock->empty()) {
curBlock->erase();
if (returnBlock != nullptr && returnBlock->getUses().empty())
returnBlock->erase();
for (mlir::Block *retBlock : retBlocks) {
if (retBlock->getUses().empty())
retBlock->erase();
}
return;
}

Expand Down
81 changes: 53 additions & 28 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1040,44 +1040,69 @@ class CIRGenFunction : public CIRGenTypeCache {
// ---

private:
// `returnBlock`, `returnLoc`, and all the functions that deal with them
// will change and become more complicated when `switch` statements are
// upstreamed. `case` statements within the `switch` are in the same scope
// but have their own regions. Therefore the LexicalScope will need to
// keep track of multiple return blocks.
mlir::Block *returnBlock = nullptr;
std::optional<mlir::Location> returnLoc;

// See the comment on `getOrCreateRetBlock`.
// On switches we need one return block per region, since cases don't
// have their own scopes but are distinct regions nonetheless.

// TODO: This implementation should change once we have support for early
// exits in MLIR structured control flow (llvm-project#161575)
llvm::SmallVector<mlir::Block *> retBlocks;
llvm::DenseMap<mlir::Block *, mlir::Location> retLocs;
llvm::DenseMap<cir::CaseOp, unsigned> retBlockInCaseIndex;
std::optional<unsigned> normalRetBlockIndex;

// There's usually only one ret block per scope, but this needs to be
// get or create because of potential unreachable return statements, note
// that for those, all source location maps to the first one found.
mlir::Block *createRetBlock(CIRGenFunction &cgf, mlir::Location loc) {
assert(returnBlock == nullptr && "only one return block per scope");
// Create the cleanup block but don't hook it up just yet.
assert((isa_and_nonnull<cir::CaseOp>(
cgf.builder.getBlock()->getParentOp()) ||
retBlocks.size() == 0) &&
"only switches can hold more than one ret block");

// Create the return block but don't hook it up just yet.
mlir::OpBuilder::InsertionGuard guard(cgf.builder);
returnBlock =
cgf.builder.createBlock(cgf.builder.getBlock()->getParent());
updateRetLoc(returnBlock, loc);
return returnBlock;
auto *b = cgf.builder.createBlock(cgf.builder.getBlock()->getParent());
retBlocks.push_back(b);
updateRetLoc(b, loc);
return b;
}

cir::ReturnOp emitReturn(mlir::Location loc);
void emitImplicitReturn();

public:
mlir::Block *getRetBlock() { return returnBlock; }
mlir::Location getRetLoc(mlir::Block *b) { return *returnLoc; }
void updateRetLoc(mlir::Block *b, mlir::Location loc) { returnLoc = loc; }

// Create the return block for this scope, or return the existing one.
// This get-or-create logic is necessary to handle multiple return
// statements within the same scope, which can happen if some of them are
// dead code or if there is a `goto` into the middle of the scope.
llvm::ArrayRef<mlir::Block *> getRetBlocks() { return retBlocks; }
mlir::Location getRetLoc(mlir::Block *b) { return retLocs.at(b); }
void updateRetLoc(mlir::Block *b, mlir::Location loc) {
retLocs.insert_or_assign(b, loc);
}

mlir::Block *getOrCreateRetBlock(CIRGenFunction &cgf, mlir::Location loc) {
if (returnBlock == nullptr) {
returnBlock = createRetBlock(cgf, loc);
return returnBlock;
// Check if we're inside a case region
if (auto caseOp = mlir::dyn_cast_if_present<cir::CaseOp>(
cgf.builder.getBlock()->getParentOp())) {
auto iter = retBlockInCaseIndex.find(caseOp);
if (iter != retBlockInCaseIndex.end()) {
// Reuse existing return block
mlir::Block *ret = retBlocks[iter->second];
updateRetLoc(ret, loc);
return ret;
}
// Create new return block
mlir::Block *ret = createRetBlock(cgf, loc);
retBlockInCaseIndex[caseOp] = retBlocks.size() - 1;
return ret;
}
updateRetLoc(returnBlock, loc);
return returnBlock;

if (normalRetBlockIndex) {
mlir::Block *ret = retBlocks[*normalRetBlockIndex];
updateRetLoc(ret, loc);
return ret;
}

mlir::Block *ret = createRetBlock(cgf, loc);
normalRetBlockIndex = retBlocks.size() - 1;
return ret;
}

mlir::Block *getEntryBlock() { return entryBlock; }
Expand Down
87 changes: 87 additions & 0 deletions clang/test/CIR/CodeGen/switch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1183,3 +1183,90 @@ int nested_switch(int a) {
// OGCG: [[IFEND10]]:
// OGCG: br label %[[EPILOG]]
// OGCG: [[EPILOG]]:

int sw_return_multi_cases(int x) {
switch (x) {
case 0:
return 0;
case 1:
return 1;
case 2:
return 2;
default:
return -1;
}
}

// CIR-LABEL: cir.func{{.*}} @_Z21sw_return_multi_casesi
// CIR: cir.switch (%{{.*}} : !s32i) {
// CIR-NEXT: cir.case(equal, [#cir.int<0> : !s32i]) {
// CIR: %[[ZERO:.*]] = cir.const #cir.int<0> : !s32i
// CIR: cir.store{{.*}} %[[ZERO]], %{{.*}} : !s32i, !cir.ptr<!s32i>
// CIR: %[[RET0:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
// CIR-NEXT: cir.return %[[RET0]] : !s32i
// CIR-NEXT: }
// CIR-NEXT: cir.case(equal, [#cir.int<1> : !s32i]) {
// CIR: %[[ONE:.*]] = cir.const #cir.int<1> : !s32i
// CIR: cir.store{{.*}} %[[ONE]], %{{.*}} : !s32i, !cir.ptr<!s32i>
// CIR: %[[RET1:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
// CIR-NEXT: cir.return %[[RET1]] : !s32i
// CIR-NEXT: }
// CIR-NEXT: cir.case(equal, [#cir.int<2> : !s32i]) {
// CIR: %[[TWO:.*]] = cir.const #cir.int<2> : !s32i
// CIR: cir.store{{.*}} %[[TWO]], %{{.*}} : !s32i, !cir.ptr<!s32i>
// CIR: %[[RET2:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
// CIR-NEXT: cir.return %[[RET2]] : !s32i
// CIR-NEXT: }
// CIR-NEXT: cir.case(default, []) {
// CIR: %[[ONE:.*]] = cir.const #cir.int<1> : !s32i
// CIR: %[[NEG:.*]] = cir.unary(minus, %[[ONE]]) {{.*}} : !s32i, !s32i
// CIR: cir.store{{.*}} %[[NEG]], %{{.*}} : !s32i, !cir.ptr<!s32i>
// CIR: %[[RETDEF:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
// CIR-NEXT: cir.return %[[RETDEF]] : !s32i
// CIR-NEXT: }
// CIR-NEXT: cir.yield

// LLVM-LABEL: define{{.*}} i32 @_Z21sw_return_multi_casesi
// LLVM: switch i32 %{{.*}}, label %[[DEFAULT:.*]] [
// LLVM-DAG: i32 0, label %[[CASE0:.*]]
// LLVM-DAG: i32 1, label %[[CASE1:.*]]
// LLVM-DAG: i32 2, label %[[CASE2:.*]]
// LLVM: ]
// LLVM: [[CASE0]]:
// LLVM: store i32 0, ptr %{{.*}}, align 4
// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4
// LLVM: ret i32 %{{.*}}
// LLVM: [[CASE1]]:
// LLVM: store i32 1, ptr %{{.*}}, align 4
// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4
// LLVM: ret i32 %{{.*}}
// LLVM: [[CASE2]]:
// LLVM: store i32 2, ptr %{{.*}}, align 4
// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4
// LLVM: ret i32 %{{.*}}
// LLVM: [[DEFAULT]]:
// LLVM: store i32 -1, ptr %{{.*}}, align 4
// LLVM: %{{.*}} = load i32, ptr %{{.*}}, align 4
// LLVM: ret i32 %{{.*}}

// OGCG-LABEL: define{{.*}} i32 @_Z21sw_return_multi_casesi
// OGCG: entry:
// OGCG: %[[RETVAL:.*]] = alloca i32, align 4
// OGCG: %[[X_ADDR:.*]] = alloca i32, align 4
// OGCG: %[[X_VAL:.*]] = load i32, ptr %[[X_ADDR]], align 4
// OGCG: switch i32 %[[X_VAL]], label %[[DEFAULT:.*]] [
// OGCG-DAG: i32 0, label %[[SW0:.*]]
// OGCG-DAG: i32 1, label %[[SW1:.*]]
// OGCG-DAG: i32 2, label %[[SW2:.*]]
// OGCG: ]
// OGCG: [[SW0]]:
// OGCG: br label %[[RETURN:.*]]
// OGCG: [[SW1]]:
// OGCG: br label %[[RETURN]]
// OGCG: [[SW2]]:
// OGCG: br label %[[RETURN]]
// OGCG: [[DEFAULT]]:
// OGCG: br label %[[RETURN]]
// OGCG: [[RETURN]]:
// OGCG: %[[RETVAL_LOAD:.*]] = load i32, ptr %[[RETVAL]], align 4
// OGCG: ret i32 %[[RETVAL_LOAD]]