Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/include/clang/CIR/Dialect/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ std::unique_ptr<Pass> createCIRSimplifyPass();
std::unique_ptr<Pass> createHoistAllocasPass();
std::unique_ptr<Pass> createLoweringPreparePass();
std::unique_ptr<Pass> createLoweringPreparePass(clang::ASTContext *astCtx);
std::unique_ptr<Pass> createGotoSolverPass();

void populateCIRPreLoweringPasses(mlir::OpPassManager &pm);

Expand Down
10 changes: 10 additions & 0 deletions clang/include/clang/CIR/Dialect/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ def CIRFlattenCFG : Pass<"cir-flatten-cfg"> {
let dependentDialects = ["cir::CIRDialect"];
}

def GotoSolver : Pass<"cir-goto-solver"> {
let summary = "Replaces goto operations with branches";
let description = [{
This pass transforms CIR and replaces goto-s with branch
operations to the proper blocks.
}];
let constructor = "mlir::createGotoSolverPass()";
let dependentDialects = ["cir::CIRDialect"];
}

def LoweringPrepare : Pass<"cir-lowering-prepare"> {
let summary = "Lower to more fine-grained CIR operations before lowering to "
"other dialects";
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_clang_library(MLIRCIRTransforms
FlattenCFG.cpp
HoistAllocas.cpp
LoweringPrepare.cpp
GotoSolver.cpp

DEPENDS
MLIRCIRPassIncGen
Expand Down
57 changes: 57 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/GotoSolver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
//====- GotoSolver.cpp -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
Copy link
Member

@bcardosolopes bcardosolopes Aug 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need the LLVM comment header boilerplate here!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching that.

#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/Passes.h"
#include "llvm/Support/TimeProfiler.h"
#include <memory>

using namespace mlir;
using namespace cir;

namespace {

struct GotoSolverPass : public GotoSolverBase<GotoSolverPass> {
GotoSolverPass() = default;
void runOnOperation() override;
};

static void process(cir::FuncOp func) {
mlir::OpBuilder rewriter(func.getContext());
llvm::StringMap<Block *> labels;
llvm::SmallVector<cir::GotoOp, 4> gotos;

func.getBody().walk([&](mlir::Operation *op) {
if (auto lab = dyn_cast<cir::LabelOp>(op)) {
// Will construct a string copy inplace. Safely erase the label
labels.try_emplace(lab.getLabel(), lab->getBlock());
lab.erase();
} else if (auto goTo = dyn_cast<cir::GotoOp>(op)) {
gotos.push_back(goTo);
}
});

for (auto goTo : gotos) {
mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(goTo);
Block *dest = labels[goTo.getLabel()];
cir::BrOp::create(rewriter, goTo.getLoc(), dest);
goTo.erase();
}
}

void GotoSolverPass::runOnOperation() {
llvm::TimeTraceScope scope("Goto Solver");
getOperation()->walk(&process);
}

} // namespace

std::unique_ptr<Pass> mlir::createGotoSolverPass() {
return std::make_unique<GotoSolverPass>();
}
1 change: 1 addition & 0 deletions clang/lib/CIR/Lowering/CIRPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ namespace mlir {
void populateCIRPreLoweringPasses(OpPassManager &pm) {
pm.addPass(createHoistAllocasPass());
pm.addPass(createCIRFlattenCFGPass());
pm.addPass(createGotoSolverPass());
}

} // namespace mlir
95 changes: 95 additions & 0 deletions clang/test/CIR/CodeGen/goto.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll
// RUN: FileCheck --input-file=%t-cir.ll %s --check-prefix=LLVM
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=OGCG

Expand Down Expand Up @@ -27,6 +29,24 @@ int shouldNotGenBranchRet(int x) {
// CIR: cir.store [[MINUS]], [[RETVAL]] : !s32i, !cir.ptr<!s32i>
// CIR: cir.br ^bb1

// LLVM: define dso_local i32 @_Z21shouldNotGenBranchReti
// LLVM: [[COND:%.*]] = load i32, ptr {{.*}}, align 4
// LLVM: [[CMP:%.*]] = icmp sgt i32 [[COND]], 5
// LLVM: br i1 [[CMP]], label %[[IFTHEN:.*]], label %[[IFEND:.*]]
// LLVM: [[IFTHEN]]:
// LLVM: br label %[[ERR:.*]]
// LLVM: [[IFEND]]:
// LLVM: br label %[[BB9:.*]]
// LLVM: [[BB9]]:
// LLVM: store i32 0, ptr %[[RETVAL:.*]], align 4
// LLVM: br label %[[BBRET:.*]]
// LLVM: [[BBRET]]:
// LLVM: [[RET:%.*]] = load i32, ptr %[[RETVAL]], align 4
// LLVM: ret i32 [[RET]]
// LLVM: [[ERR]]:
// LLVM: store i32 -1, ptr %[[RETVAL]], align 4
// LLVM: br label %10

// OGCG: define dso_local noundef i32 @_Z21shouldNotGenBranchReti
// OGCG: if.then:
// OGCG: br label %err
Expand All @@ -51,6 +71,17 @@ int shouldGenBranch(int x) {
// CIR: ^bb1:
// CIR: cir.label "err"

// LLVM: define dso_local i32 @_Z15shouldGenBranchi
// LLVM: br i1 [[CMP:%.*]], label %[[IFTHEN:.*]], label %[[IFEND:.*]]
// LLVM: [[IFTHEN]]:
// LLVM: br label %[[ERR:.*]]
// LLVM: [[IFEND]]:
// LLVM: br label %[[BB9:.*]]
// LLVM: [[BB9]]:
// LLVM: br label %[[ERR]]
// LLVM: [[ERR]]:
// LLVM: ret i32 [[RET:%.*]]

// OGCG: define dso_local noundef i32 @_Z15shouldGenBranchi
// OGCG: if.then:
// OGCG: br label %err
Expand Down Expand Up @@ -78,6 +109,15 @@ void severalLabelsInARow(int a) {
// CIR: ^bb[[#BLK3]]:
// CIR: cir.label "end2"

// LLVM: define dso_local void @_Z19severalLabelsInARowi
// LLVM: br label %[[END1:.*]]
// LLVM: [[UNRE:.*]]: ; No predecessors!
// LLVM: br label %[[END2:.*]]
// LLVM: [[END1]]:
// LLVM: br label %[[END2]]
// LLVM: [[END2]]:
// LLVM: ret

// OGCG: define dso_local void @_Z19severalLabelsInARowi
// OGCG: br label %end1
// OGCG: end1:
Expand All @@ -99,6 +139,13 @@ void severalGotosInARow(int a) {
// CIR: ^bb[[#BLK2:]]:
// CIR: cir.label "end"

// LLVM: define dso_local void @_Z18severalGotosInARowi
// LLVM: br label %[[END:.*]]
// LLVM: [[UNRE:.*]]: ; No predecessors!
// LLVM: br label %[[END]]
// LLVM: [[END]]:
// LLVM: ret void

// OGCG: define dso_local void @_Z18severalGotosInARowi(i32 noundef %a) #0 {
// OGCG: br label %end
// OGCG: end:
Expand Down Expand Up @@ -126,6 +173,14 @@ extern "C" void multiple_non_case(int v) {
// CIR: cir.call @action2()
// CIR: cir.break

// LLVM: define dso_local void @multiple_non_case
// LLVM: [[SWDEFAULT:.*]]:
// LLVM: call void @action1()
// LLVM: br label %[[L2:.*]]
// LLVM: [[L2]]:
// LLVM: call void @action2()
// LLVM: br label %[[BREAK:.*]]

// OGCG: define dso_local void @multiple_non_case
// OGCG: sw.default:
// OGCG: call void @action1()
Expand Down Expand Up @@ -158,6 +213,26 @@ extern "C" void case_follow_label(int v) {
// CIR: cir.call @action2()
// CIR: cir.goto "label"

// LLVM: define dso_local void @case_follow_label
// LLVM: switch i32 {{.*}}, label %[[SWDEFAULT:.*]] [
// LLVM: i32 1, label %[[LABEL:.*]]
// LLVM: i32 2, label %[[CASE2:.*]]
// LLVM: ]
// LLVM: [[LABEL]]:
// LLVM: br label %[[CASE2]]
// LLVM: [[CASE2]]:
// LLVM: call void @action1()
// LLVM: br label %[[BREAK:.*]]
// LLVM: [[BREAK]]:
// LLVM: br label %[[END:.*]]
// LLVM: [[SWDEFAULT]]:
// LLVM: call void @action2()
// LLVM: br label %[[LABEL]]
// LLVM: [[END]]:
// LLVM: br label %[[RET:.*]]
// LLVM: [[RET]]:
// LLVM: ret void

// OGCG: define dso_local void @case_follow_label
// OGCG: sw.bb:
// OGCG: br label %label
Expand Down Expand Up @@ -197,6 +272,26 @@ extern "C" void default_follow_label(int v) {
// CIR: cir.call @action2()
// CIR: cir.goto "label"

// LLVM: define dso_local void @default_follow_label
// LLVM: [[CASE1:.*]]:
// LLVM: br label %[[BB8:.*]]
// LLVM: [[BB8]]:
// LLVM: br label %[[CASE2:.*]]
// LLVM: [[CASE2]]:
// LLVM: call void @action1()
// LLVM: br label %[[BREAK:.*]]
// LLVM: [[LABEL:.*]]:
// LLVM: br label %[[SWDEFAULT:.*]]
// LLVM: [[SWDEFAULT]]:
// LLVM: call void @action2()
// LLVM: br label %[[BB9:.*]]
// LLVM: [[BB9]]:
// LLVM: br label %[[LABEL]]
// LLVM: [[BREAK]]:
// LLVM: br label %[[RET:.*]]
// LLVM: [[RET]]:
// LLVM: ret void

// OGCG: define dso_local void @default_follow_label
// OGCG: sw.bb:
// OGCG: call void @action1()
Expand Down
51 changes: 49 additions & 2 deletions clang/test/CIR/CodeGen/label.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll
// RUN: FileCheck --input-file=%t-cir.ll %s --check-prefix=LLVM
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=OGCG

Expand All @@ -12,8 +14,8 @@ void label() {
// CIR: cir.label "labelA"
// CIR: cir.return

// Note: We are not lowering to LLVM IR via CIR at this stage because that
// process depends on the GotoSolver.
// LLVM:define dso_local void @label
// LLVM: ret void

// OGCG: define dso_local void @label
// OGCG: br label %labelA
Expand All @@ -33,6 +35,11 @@ void multiple_labels() {
// CIR: cir.label "labelC"
// CIR: cir.return

// LLVM: define dso_local void @multiple_labels()
// LLVM: br label %1
// LLVM: 1:
// LLVM: ret void

// OGCG: define dso_local void @multiple_labels
// OGCG: br label %labelB
// OGCG: labelB:
Expand All @@ -56,6 +63,22 @@ void label_in_if(int cond) {
// CIR: }
// CIR: cir.return

// LLVM: define dso_local void @label_in_if
// LLVM: br label %3
// LLVM: 3:
// LLVM: [[LOAD:%.*]] = load i32, ptr [[COND:%.*]], align 4
// LLVM: [[CMP:%.*]] = icmp ne i32 [[LOAD]], 0
// LLVM: br i1 [[CMP]], label %6, label %9
// LLVM: 6:
// LLVM: [[LOAD2:%.*]] = load i32, ptr [[COND]], align 4
// LLVM: [[ADD1:%.*]] = add nsw i32 [[LOAD2]], 1
// LLVM: store i32 [[ADD1]], ptr [[COND]], align 4
// LLVM: br label %9
// LLVM: 9:
// LLVM: br label %10
// LLVM: 10:
// LLVM: ret void

// OGCG: define dso_local void @label_in_if
// OGCG: if.then:
// OGCG: br label %labelD
Expand All @@ -80,6 +103,13 @@ void after_return() {
// CIR: cir.label "label"
// CIR: cir.br ^bb1

// LLVM: define dso_local void @after_return
// LLVM: br label %1
// LLVM: 1:
// LLVM: ret void
// LLVM: 2:
// LLVM: br label %1

// OGCG: define dso_local void @after_return
// OGCG: br label %label
// OGCG: label:
Expand All @@ -97,6 +127,11 @@ void after_unreachable() {
// CIR: cir.label "label"
// CIR: cir.return

// LLVM: define dso_local void @after_unreachable
// LLVM: unreachable
// LLVM: 1:
// LLVM: ret void

// OGCG: define dso_local void @after_unreachable
// OGCG: unreachable
// OGCG: label:
Expand All @@ -111,6 +146,9 @@ void labelWithoutMatch() {
// CIR: cir.return
// CIR: }

// LLVM: define dso_local void @labelWithoutMatch
// LLVM: ret void

// OGCG: define dso_local void @labelWithoutMatch
// OGCG: br label %end
// OGCG: end:
Expand All @@ -132,6 +170,15 @@ void foo() {
// CIR: cir.label "label"
// CIR: %0 = cir.alloca !rec_S, !cir.ptr<!rec_S>, ["agg.tmp0"]

// LLVM:define dso_local void @foo() {
// LLVM: [[ALLOC:%.*]] = alloca %struct.S, i64 1, align 1
// LLVM: br label %2
// LLVM:2:
// LLVM: [[CALL:%.*]] = call %struct.S @get()
// LLVM: store %struct.S [[CALL]], ptr [[ALLOC]], align 1
// LLVM: [[LOAD:%.*]] = load %struct.S, ptr [[ALLOC]], align 1
// LLVM: call void @bar(%struct.S [[LOAD]])

// OGCG: define dso_local void @foo()
// OGCG: %agg.tmp = alloca %struct.S, align 1
// OGCG: %undef.agg.tmp = alloca %struct.S, align 1
Expand Down
Loading