Skip to content

Commit fc62990

Browse files
[CIR] Upstream GotoSolver pass (#154596)
This PR upstreams the GotoSolver pass. It works by walking the function and matching each label to a goto. If a label is not matched to a goto, it is removed and not lowered.
1 parent 3923adf commit fc62990

File tree

8 files changed

+266
-2
lines changed

8 files changed

+266
-2
lines changed

clang/include/clang/CIR/Dialect/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ std::unique_ptr<Pass> createCIRSimplifyPass();
2626
std::unique_ptr<Pass> createHoistAllocasPass();
2727
std::unique_ptr<Pass> createLoweringPreparePass();
2828
std::unique_ptr<Pass> createLoweringPreparePass(clang::ASTContext *astCtx);
29+
std::unique_ptr<Pass> createGotoSolverPass();
2930

3031
void populateCIRPreLoweringPasses(mlir::OpPassManager &pm);
3132

clang/include/clang/CIR/Dialect/Passes.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,16 @@ def CIRFlattenCFG : Pass<"cir-flatten-cfg"> {
7272
let dependentDialects = ["cir::CIRDialect"];
7373
}
7474

75+
def GotoSolver : Pass<"cir-goto-solver"> {
76+
let summary = "Replaces goto operations with branches";
77+
let description = [{
78+
This pass transforms CIR and replaces goto-s with branch
79+
operations to the proper blocks.
80+
}];
81+
let constructor = "mlir::createGotoSolverPass()";
82+
let dependentDialects = ["cir::CIRDialect"];
83+
}
84+
7585
def LoweringPrepare : Pass<"cir-lowering-prepare"> {
7686
let summary = "Lower to more fine-grained CIR operations before lowering to "
7787
"other dialects";

clang/lib/CIR/Dialect/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_clang_library(MLIRCIRTransforms
44
FlattenCFG.cpp
55
HoistAllocas.cpp
66
LoweringPrepare.cpp
7+
GotoSolver.cpp
78

89
DEPENDS
910
MLIRCIRPassIncGen
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
//====- GotoSolver.cpp -----------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#include "PassDetail.h"
9+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
10+
#include "clang/CIR/Dialect/Passes.h"
11+
#include "llvm/Support/TimeProfiler.h"
12+
#include <memory>
13+
14+
using namespace mlir;
15+
using namespace cir;
16+
17+
namespace {
18+
19+
struct GotoSolverPass : public GotoSolverBase<GotoSolverPass> {
20+
GotoSolverPass() = default;
21+
void runOnOperation() override;
22+
};
23+
24+
static void process(cir::FuncOp func) {
25+
mlir::OpBuilder rewriter(func.getContext());
26+
llvm::StringMap<Block *> labels;
27+
llvm::SmallVector<cir::GotoOp, 4> gotos;
28+
29+
func.getBody().walk([&](mlir::Operation *op) {
30+
if (auto lab = dyn_cast<cir::LabelOp>(op)) {
31+
// Will construct a string copy inplace. Safely erase the label
32+
labels.try_emplace(lab.getLabel(), lab->getBlock());
33+
lab.erase();
34+
} else if (auto goTo = dyn_cast<cir::GotoOp>(op)) {
35+
gotos.push_back(goTo);
36+
}
37+
});
38+
39+
for (auto goTo : gotos) {
40+
mlir::OpBuilder::InsertionGuard guard(rewriter);
41+
rewriter.setInsertionPoint(goTo);
42+
Block *dest = labels[goTo.getLabel()];
43+
cir::BrOp::create(rewriter, goTo.getLoc(), dest);
44+
goTo.erase();
45+
}
46+
}
47+
48+
void GotoSolverPass::runOnOperation() {
49+
llvm::TimeTraceScope scope("Goto Solver");
50+
getOperation()->walk(&process);
51+
}
52+
53+
} // namespace
54+
55+
std::unique_ptr<Pass> mlir::createGotoSolverPass() {
56+
return std::make_unique<GotoSolverPass>();
57+
}

clang/lib/CIR/Lowering/CIRPasses.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ namespace mlir {
4545
void populateCIRPreLoweringPasses(OpPassManager &pm) {
4646
pm.addPass(createHoistAllocasPass());
4747
pm.addPass(createCIRFlattenCFGPass());
48+
pm.addPass(createGotoSolverPass());
4849
}
4950

5051
} // namespace mlir

clang/test/CIR/CodeGen/goto.cpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
22
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll
4+
// RUN: FileCheck --input-file=%t-cir.ll %s --check-prefix=LLVM
35
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll
46
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=OGCG
57

@@ -27,6 +29,24 @@ int shouldNotGenBranchRet(int x) {
2729
// CIR: cir.store [[MINUS]], [[RETVAL]] : !s32i, !cir.ptr<!s32i>
2830
// CIR: cir.br ^bb1
2931

32+
// LLVM: define dso_local i32 @_Z21shouldNotGenBranchReti
33+
// LLVM: [[COND:%.*]] = load i32, ptr {{.*}}, align 4
34+
// LLVM: [[CMP:%.*]] = icmp sgt i32 [[COND]], 5
35+
// LLVM: br i1 [[CMP]], label %[[IFTHEN:.*]], label %[[IFEND:.*]]
36+
// LLVM: [[IFTHEN]]:
37+
// LLVM: br label %[[ERR:.*]]
38+
// LLVM: [[IFEND]]:
39+
// LLVM: br label %[[BB9:.*]]
40+
// LLVM: [[BB9]]:
41+
// LLVM: store i32 0, ptr %[[RETVAL:.*]], align 4
42+
// LLVM: br label %[[BBRET:.*]]
43+
// LLVM: [[BBRET]]:
44+
// LLVM: [[RET:%.*]] = load i32, ptr %[[RETVAL]], align 4
45+
// LLVM: ret i32 [[RET]]
46+
// LLVM: [[ERR]]:
47+
// LLVM: store i32 -1, ptr %[[RETVAL]], align 4
48+
// LLVM: br label %10
49+
3050
// OGCG: define dso_local noundef i32 @_Z21shouldNotGenBranchReti
3151
// OGCG: if.then:
3252
// OGCG: br label %err
@@ -51,6 +71,17 @@ int shouldGenBranch(int x) {
5171
// CIR: ^bb1:
5272
// CIR: cir.label "err"
5373

74+
// LLVM: define dso_local i32 @_Z15shouldGenBranchi
75+
// LLVM: br i1 [[CMP:%.*]], label %[[IFTHEN:.*]], label %[[IFEND:.*]]
76+
// LLVM: [[IFTHEN]]:
77+
// LLVM: br label %[[ERR:.*]]
78+
// LLVM: [[IFEND]]:
79+
// LLVM: br label %[[BB9:.*]]
80+
// LLVM: [[BB9]]:
81+
// LLVM: br label %[[ERR]]
82+
// LLVM: [[ERR]]:
83+
// LLVM: ret i32 [[RET:%.*]]
84+
5485
// OGCG: define dso_local noundef i32 @_Z15shouldGenBranchi
5586
// OGCG: if.then:
5687
// OGCG: br label %err
@@ -78,6 +109,15 @@ void severalLabelsInARow(int a) {
78109
// CIR: ^bb[[#BLK3]]:
79110
// CIR: cir.label "end2"
80111

112+
// LLVM: define dso_local void @_Z19severalLabelsInARowi
113+
// LLVM: br label %[[END1:.*]]
114+
// LLVM: [[UNRE:.*]]: ; No predecessors!
115+
// LLVM: br label %[[END2:.*]]
116+
// LLVM: [[END1]]:
117+
// LLVM: br label %[[END2]]
118+
// LLVM: [[END2]]:
119+
// LLVM: ret
120+
81121
// OGCG: define dso_local void @_Z19severalLabelsInARowi
82122
// OGCG: br label %end1
83123
// OGCG: end1:
@@ -99,6 +139,13 @@ void severalGotosInARow(int a) {
99139
// CIR: ^bb[[#BLK2:]]:
100140
// CIR: cir.label "end"
101141

142+
// LLVM: define dso_local void @_Z18severalGotosInARowi
143+
// LLVM: br label %[[END:.*]]
144+
// LLVM: [[UNRE:.*]]: ; No predecessors!
145+
// LLVM: br label %[[END]]
146+
// LLVM: [[END]]:
147+
// LLVM: ret void
148+
102149
// OGCG: define dso_local void @_Z18severalGotosInARowi(i32 noundef %a) #0 {
103150
// OGCG: br label %end
104151
// OGCG: end:
@@ -126,6 +173,14 @@ extern "C" void multiple_non_case(int v) {
126173
// CIR: cir.call @action2()
127174
// CIR: cir.break
128175

176+
// LLVM: define dso_local void @multiple_non_case
177+
// LLVM: [[SWDEFAULT:.*]]:
178+
// LLVM: call void @action1()
179+
// LLVM: br label %[[L2:.*]]
180+
// LLVM: [[L2]]:
181+
// LLVM: call void @action2()
182+
// LLVM: br label %[[BREAK:.*]]
183+
129184
// OGCG: define dso_local void @multiple_non_case
130185
// OGCG: sw.default:
131186
// OGCG: call void @action1()
@@ -158,6 +213,26 @@ extern "C" void case_follow_label(int v) {
158213
// CIR: cir.call @action2()
159214
// CIR: cir.goto "label"
160215

216+
// LLVM: define dso_local void @case_follow_label
217+
// LLVM: switch i32 {{.*}}, label %[[SWDEFAULT:.*]] [
218+
// LLVM: i32 1, label %[[LABEL:.*]]
219+
// LLVM: i32 2, label %[[CASE2:.*]]
220+
// LLVM: ]
221+
// LLVM: [[LABEL]]:
222+
// LLVM: br label %[[CASE2]]
223+
// LLVM: [[CASE2]]:
224+
// LLVM: call void @action1()
225+
// LLVM: br label %[[BREAK:.*]]
226+
// LLVM: [[BREAK]]:
227+
// LLVM: br label %[[END:.*]]
228+
// LLVM: [[SWDEFAULT]]:
229+
// LLVM: call void @action2()
230+
// LLVM: br label %[[LABEL]]
231+
// LLVM: [[END]]:
232+
// LLVM: br label %[[RET:.*]]
233+
// LLVM: [[RET]]:
234+
// LLVM: ret void
235+
161236
// OGCG: define dso_local void @case_follow_label
162237
// OGCG: sw.bb:
163238
// OGCG: br label %label
@@ -197,6 +272,26 @@ extern "C" void default_follow_label(int v) {
197272
// CIR: cir.call @action2()
198273
// CIR: cir.goto "label"
199274

275+
// LLVM: define dso_local void @default_follow_label
276+
// LLVM: [[CASE1:.*]]:
277+
// LLVM: br label %[[BB8:.*]]
278+
// LLVM: [[BB8]]:
279+
// LLVM: br label %[[CASE2:.*]]
280+
// LLVM: [[CASE2]]:
281+
// LLVM: call void @action1()
282+
// LLVM: br label %[[BREAK:.*]]
283+
// LLVM: [[LABEL:.*]]:
284+
// LLVM: br label %[[SWDEFAULT:.*]]
285+
// LLVM: [[SWDEFAULT]]:
286+
// LLVM: call void @action2()
287+
// LLVM: br label %[[BB9:.*]]
288+
// LLVM: [[BB9]]:
289+
// LLVM: br label %[[LABEL]]
290+
// LLVM: [[BREAK]]:
291+
// LLVM: br label %[[RET:.*]]
292+
// LLVM: [[RET]]:
293+
// LLVM: ret void
294+
200295
// OGCG: define dso_local void @default_follow_label
201296
// OGCG: sw.bb:
202297
// OGCG: call void @action1()

clang/test/CIR/CodeGen/label.c

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
22
// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll
4+
// RUN: FileCheck --input-file=%t-cir.ll %s --check-prefix=LLVM
35
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll
46
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=OGCG
57

@@ -12,8 +14,8 @@ void label() {
1214
// CIR: cir.label "labelA"
1315
// CIR: cir.return
1416

15-
// Note: We are not lowering to LLVM IR via CIR at this stage because that
16-
// process depends on the GotoSolver.
17+
// LLVM:define dso_local void @label
18+
// LLVM: ret void
1719

1820
// OGCG: define dso_local void @label
1921
// OGCG: br label %labelA
@@ -33,6 +35,11 @@ void multiple_labels() {
3335
// CIR: cir.label "labelC"
3436
// CIR: cir.return
3537

38+
// LLVM: define dso_local void @multiple_labels()
39+
// LLVM: br label %1
40+
// LLVM: 1:
41+
// LLVM: ret void
42+
3643
// OGCG: define dso_local void @multiple_labels
3744
// OGCG: br label %labelB
3845
// OGCG: labelB:
@@ -56,6 +63,22 @@ void label_in_if(int cond) {
5663
// CIR: }
5764
// CIR: cir.return
5865

66+
// LLVM: define dso_local void @label_in_if
67+
// LLVM: br label %3
68+
// LLVM: 3:
69+
// LLVM: [[LOAD:%.*]] = load i32, ptr [[COND:%.*]], align 4
70+
// LLVM: [[CMP:%.*]] = icmp ne i32 [[LOAD]], 0
71+
// LLVM: br i1 [[CMP]], label %6, label %9
72+
// LLVM: 6:
73+
// LLVM: [[LOAD2:%.*]] = load i32, ptr [[COND]], align 4
74+
// LLVM: [[ADD1:%.*]] = add nsw i32 [[LOAD2]], 1
75+
// LLVM: store i32 [[ADD1]], ptr [[COND]], align 4
76+
// LLVM: br label %9
77+
// LLVM: 9:
78+
// LLVM: br label %10
79+
// LLVM: 10:
80+
// LLVM: ret void
81+
5982
// OGCG: define dso_local void @label_in_if
6083
// OGCG: if.then:
6184
// OGCG: br label %labelD
@@ -80,6 +103,13 @@ void after_return() {
80103
// CIR: cir.label "label"
81104
// CIR: cir.br ^bb1
82105

106+
// LLVM: define dso_local void @after_return
107+
// LLVM: br label %1
108+
// LLVM: 1:
109+
// LLVM: ret void
110+
// LLVM: 2:
111+
// LLVM: br label %1
112+
83113
// OGCG: define dso_local void @after_return
84114
// OGCG: br label %label
85115
// OGCG: label:
@@ -97,6 +127,11 @@ void after_unreachable() {
97127
// CIR: cir.label "label"
98128
// CIR: cir.return
99129

130+
// LLVM: define dso_local void @after_unreachable
131+
// LLVM: unreachable
132+
// LLVM: 1:
133+
// LLVM: ret void
134+
100135
// OGCG: define dso_local void @after_unreachable
101136
// OGCG: unreachable
102137
// OGCG: label:
@@ -111,6 +146,9 @@ void labelWithoutMatch() {
111146
// CIR: cir.return
112147
// CIR: }
113148

149+
// LLVM: define dso_local void @labelWithoutMatch
150+
// LLVM: ret void
151+
114152
// OGCG: define dso_local void @labelWithoutMatch
115153
// OGCG: br label %end
116154
// OGCG: end:
@@ -132,6 +170,15 @@ void foo() {
132170
// CIR: cir.label "label"
133171
// CIR: %0 = cir.alloca !rec_S, !cir.ptr<!rec_S>, ["agg.tmp0"]
134172

173+
// LLVM:define dso_local void @foo() {
174+
// LLVM: [[ALLOC:%.*]] = alloca %struct.S, i64 1, align 1
175+
// LLVM: br label %2
176+
// LLVM:2:
177+
// LLVM: [[CALL:%.*]] = call %struct.S @get()
178+
// LLVM: store %struct.S [[CALL]], ptr [[ALLOC]], align 1
179+
// LLVM: [[LOAD:%.*]] = load %struct.S, ptr [[ALLOC]], align 1
180+
// LLVM: call void @bar(%struct.S [[LOAD]])
181+
135182
// OGCG: define dso_local void @foo()
136183
// OGCG: %agg.tmp = alloca %struct.S, align 1
137184
// OGCG: %undef.agg.tmp = alloca %struct.S, align 1

0 commit comments

Comments
 (0)