-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[CIR] Upstream GotoSolver pass #154596
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CIR] Upstream GotoSolver pass #154596
Conversation
|
@llvm/pr-subscribers-clang @llvm/pr-subscribers-clangir Author: None (Andres-Salamanca) ChangesThis PR upstreams the GotoSolver pass. Full diff: https://github.com/llvm/llvm-project/pull/154596.diff 8 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/Passes.h b/clang/include/clang/CIR/Dialect/Passes.h
index 7a202b1e04ef9..32c3e27d07dfb 100644
--- a/clang/include/clang/CIR/Dialect/Passes.h
+++ b/clang/include/clang/CIR/Dialect/Passes.h
@@ -26,6 +26,7 @@ std::unique_ptr<Pass> createCIRSimplifyPass();
std::unique_ptr<Pass> createHoistAllocasPass();
std::unique_ptr<Pass> createLoweringPreparePass();
std::unique_ptr<Pass> createLoweringPreparePass(clang::ASTContext *astCtx);
+std::unique_ptr<Pass> createGotoSolverPass();
void populateCIRPreLoweringPasses(mlir::OpPassManager &pm);
diff --git a/clang/include/clang/CIR/Dialect/Passes.td b/clang/include/clang/CIR/Dialect/Passes.td
index 7d5ec2ffed39d..0f5783945f8ae 100644
--- a/clang/include/clang/CIR/Dialect/Passes.td
+++ b/clang/include/clang/CIR/Dialect/Passes.td
@@ -72,6 +72,16 @@ def CIRFlattenCFG : Pass<"cir-flatten-cfg"> {
let dependentDialects = ["cir::CIRDialect"];
}
+def GotoSolver : Pass<"cir-goto-solver"> {
+ let summary = "Replaces goto operations with branches";
+ let description = [{
+ This pass transforms CIR and replaces goto-s with branch
+ operations to the proper blocks.
+ }];
+ let constructor = "mlir::createGotoSolverPass()";
+ let dependentDialects = ["cir::CIRDialect"];
+}
+
def LoweringPrepare : Pass<"cir-lowering-prepare"> {
let summary = "Lower to more fine-grained CIR operations before lowering to "
"other dialects";
diff --git a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
index 18beca7b9a680..df7a1a3e0acb5 100644
--- a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
+++ b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_clang_library(MLIRCIRTransforms
FlattenCFG.cpp
HoistAllocas.cpp
LoweringPrepare.cpp
+ GotoSolver.cpp
DEPENDS
MLIRCIRPassIncGen
diff --git a/clang/lib/CIR/Dialect/Transforms/GotoSolver.cpp b/clang/lib/CIR/Dialect/Transforms/GotoSolver.cpp
new file mode 100644
index 0000000000000..e1c47a1ce16f1
--- /dev/null
+++ b/clang/lib/CIR/Dialect/Transforms/GotoSolver.cpp
@@ -0,0 +1,52 @@
+#include "PassDetail.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+#include "clang/CIR/Dialect/Passes.h"
+#include "llvm/Support/TimeProfiler.h"
+#include <memory>
+
+using namespace mlir;
+using namespace cir;
+
+namespace {
+
+struct GotoSolverPass : public GotoSolverBase<GotoSolverPass> {
+
+ GotoSolverPass() = default;
+ void runOnOperation() override;
+};
+
+static void process(cir::FuncOp func) {
+
+ mlir::OpBuilder rewriter(func.getContext());
+ llvm::StringMap<Block *> labels;
+ llvm::SmallVector<cir::GotoOp, 4> gotos;
+
+ func.getBody().walk([&](mlir::Operation *op) {
+ if (auto lab = dyn_cast<cir::LabelOp>(op)) {
+ // Will construct a string copy inplace. Safely erase the label
+ labels.try_emplace(lab.getLabel(), lab->getBlock());
+ lab.erase();
+ } else if (auto goTo = dyn_cast<cir::GotoOp>(op)) {
+ gotos.push_back(goTo);
+ }
+ });
+
+ for (auto goTo : gotos) {
+ mlir::OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(goTo);
+ Block *dest = labels[goTo.getLabel()];
+ rewriter.create<cir::BrOp>(goTo.getLoc(), dest);
+ goTo.erase();
+ }
+}
+
+void GotoSolverPass::runOnOperation() {
+ llvm::TimeTraceScope scope("Goto Solver");
+ getOperation()->walk([&](cir::FuncOp op) { process(op); });
+}
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createGotoSolverPass() {
+ return std::make_unique<GotoSolverPass>();
+}
diff --git a/clang/lib/CIR/Lowering/CIRPasses.cpp b/clang/lib/CIR/Lowering/CIRPasses.cpp
index bb9781be897eb..ccc838717e421 100644
--- a/clang/lib/CIR/Lowering/CIRPasses.cpp
+++ b/clang/lib/CIR/Lowering/CIRPasses.cpp
@@ -45,6 +45,7 @@ namespace mlir {
void populateCIRPreLoweringPasses(OpPassManager &pm) {
pm.addPass(createHoistAllocasPass());
pm.addPass(createCIRFlattenCFGPass());
+ pm.addPass(createGotoSolverPass());
}
} // namespace mlir
diff --git a/clang/test/CIR/CodeGen/goto.cpp b/clang/test/CIR/CodeGen/goto.cpp
index 13ca65344a150..48cb44ed0f478 100644
--- a/clang/test/CIR/CodeGen/goto.cpp
+++ b/clang/test/CIR/CodeGen/goto.cpp
@@ -1,5 +1,7 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll
+// RUN: FileCheck --input-file=%t-cir.ll %s --check-prefix=LLVM
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=OGCG
@@ -27,6 +29,24 @@ int shouldNotGenBranchRet(int x) {
// CIR: cir.store [[MINUS]], [[RETVAL]] : !s32i, !cir.ptr<!s32i>
// CIR: cir.br ^bb1
+// LLVM: define dso_local i32 @_Z21shouldNotGenBranchReti
+// LLVM: [[COND:%.*]] = load i32, ptr {{.*}}, align 4
+// LLVM: [[CMP:%.*]] = icmp sgt i32 [[COND]], 5
+// LLVM: br i1 [[CMP]], label %[[IFTHEN:.*]], label %[[IFEND:.*]]
+// LLVM: [[IFTHEN]]:
+// LLVM: br label %[[ERR:.*]]
+// LLVM: [[IFEND]]:
+// LLVM: br label %[[BB9:.*]]
+// LLVM: [[BB9]]:
+// LLVM: store i32 0, ptr %[[RETVAL:.*]], align 4
+// LLVM: br label %[[BBRET:.*]]
+// LLVM: [[BBRET]]:
+// LLVM: [[RET:%.*]] = load i32, ptr %[[RETVAL]], align 4
+// LLVM: ret i32 [[RET]]
+// LLVM: [[ERR]]:
+// LLVM: store i32 -1, ptr %[[RETVAL]], align 4
+// LLVM: br label %10
+
// OGCG: define dso_local noundef i32 @_Z21shouldNotGenBranchReti
// OGCG: if.then:
// OGCG: br label %err
@@ -51,6 +71,17 @@ int shouldGenBranch(int x) {
// CIR: ^bb1:
// CIR: cir.label "err"
+// LLVM: define dso_local i32 @_Z15shouldGenBranchi
+// LLVM: br i1 [[CMP:%.*]], label %[[IFTHEN:.*]], label %[[IFEND:.*]]
+// LLVM: [[IFTHEN]]:
+// LLVM: br label %[[ERR:.*]]
+// LLVM: [[IFEND]]:
+// LLVM: br label %[[BB9:.*]]
+// LLVM: [[BB9]]:
+// LLVM: br label %[[ERR]]
+// LLVM: [[ERR]]:
+// LLVM: ret i32 [[RET:%.*]]
+
// OGCG: define dso_local noundef i32 @_Z15shouldGenBranchi
// OGCG: if.then:
// OGCG: br label %err
@@ -78,6 +109,15 @@ void severalLabelsInARow(int a) {
// CIR: ^bb[[#BLK3]]:
// CIR: cir.label "end2"
+// LLVM: define dso_local void @_Z19severalLabelsInARowi
+// LLVM: br label %[[END1:.*]]
+// LLVM: [[UNRE:.*]]: ; No predecessors!
+// LLVM: br label %[[END2:.*]]
+// LLVM: [[END1]]:
+// LLVM: br label %[[END2]]
+// LLVM: [[END2]]:
+// LLVM: ret
+
// OGCG: define dso_local void @_Z19severalLabelsInARowi
// OGCG: br label %end1
// OGCG: end1:
@@ -99,6 +139,13 @@ void severalGotosInARow(int a) {
// CIR: ^bb[[#BLK2:]]:
// CIR: cir.label "end"
+// LLVM: define dso_local void @_Z18severalGotosInARowi
+// LLVM: br label %[[END:.*]]
+// LLVM: [[UNRE:.*]]: ; No predecessors!
+// LLVM: br label %[[END]]
+// LLVM: [[END]]:
+// LLVM: ret void
+
// OGCG: define dso_local void @_Z18severalGotosInARowi(i32 noundef %a) #0 {
// OGCG: br label %end
// OGCG: end:
@@ -126,6 +173,14 @@ extern "C" void multiple_non_case(int v) {
// CIR: cir.call @action2()
// CIR: cir.break
+// LLVM: define dso_local void @multiple_non_case
+// LLVM: [[SWDEFAULT:.*]]:
+// LLVM: call void @action1()
+// LLVM: br label %[[L2:.*]]
+// LLVM: [[L2]]:
+// LLVM: call void @action2()
+// LLVM: br label %[[BREAK:.*]]
+
// OGCG: define dso_local void @multiple_non_case
// OGCG: sw.default:
// OGCG: call void @action1()
@@ -158,6 +213,26 @@ extern "C" void case_follow_label(int v) {
// CIR: cir.call @action2()
// CIR: cir.goto "label"
+// LLVM: define dso_local void @case_follow_label
+// LLVM: switch i32 {{.*}}, label %[[SWDEFAULT:.*]] [
+// LLVM: i32 1, label %[[LABEL:.*]]
+// LLVM: i32 2, label %[[CASE2:.*]]
+// LLVM: ]
+// LLVM: [[LABEL]]:
+// LLVM: br label %[[CASE2]]
+// LLVM: [[CASE2]]:
+// LLVM: call void @action1()
+// LLVM: br label %[[BREAK:.*]]
+// LLVM: [[BREAK]]:
+// LLVM: br label %[[END:.*]]
+// LLVM: [[SWDEFAULT]]:
+// LLVM: call void @action2()
+// LLVM: br label %[[LABEL]]
+// LLVM: [[END]]:
+// LLVM: br label %[[RET:.*]]
+// LLVM: [[RET]]:
+// LLVM: ret void
+
// OGCG: define dso_local void @case_follow_label
// OGCG: sw.bb:
// OGCG: br label %label
@@ -197,6 +272,26 @@ extern "C" void default_follow_label(int v) {
// CIR: cir.call @action2()
// CIR: cir.goto "label"
+// LLVM: define dso_local void @default_follow_label
+// LLVM: [[CASE1:.*]]:
+// LLVM: br label %[[BB8:.*]]
+// LLVM: [[BB8]]:
+// LLVM: br label %[[CASE2:.*]]
+// LLVM: [[CASE2]]:
+// LLVM: call void @action1()
+// LLVM: br label %[[BREAK:.*]]
+// LLVM: [[LABEL:.*]]:
+// LLVM: br label %[[SWDEFAULT:.*]]
+// LLVM: [[SWDEFAULT]]:
+// LLVM: call void @action2()
+// LLVM: br label %[[BB9:.*]]
+// LLVM: [[BB9]]:
+// LLVM: br label %[[LABEL]]
+// LLVM: [[BREAK]]:
+// LLVM: br label %[[RET:.*]]
+// LLVM: [[RET]]:
+// LLVM: ret void
+
// OGCG: define dso_local void @default_follow_label
// OGCG: sw.bb:
// OGCG: call void @action1()
diff --git a/clang/test/CIR/CodeGen/label.c b/clang/test/CIR/CodeGen/label.c
index 797c44475a621..a050094de678b 100644
--- a/clang/test/CIR/CodeGen/label.c
+++ b/clang/test/CIR/CodeGen/label.c
@@ -1,5 +1,7 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll
+// RUN: FileCheck --input-file=%t-cir.ll %s --check-prefix=LLVM
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=OGCG
@@ -12,8 +14,8 @@ void label() {
// CIR: cir.label "labelA"
// CIR: cir.return
-// Note: We are not lowering to LLVM IR via CIR at this stage because that
-// process depends on the GotoSolver.
+// LLVM:define dso_local void @label
+// LLVM: ret void
// OGCG: define dso_local void @label
// OGCG: br label %labelA
@@ -33,6 +35,11 @@ void multiple_labels() {
// CIR: cir.label "labelC"
// CIR: cir.return
+// LLVM: define dso_local void @multiple_labels()
+// LLVM: br label %1
+// LLVM: 1:
+// LLVM: ret void
+
// OGCG: define dso_local void @multiple_labels
// OGCG: br label %labelB
// OGCG: labelB:
@@ -56,6 +63,22 @@ void label_in_if(int cond) {
// CIR: }
// CIR: cir.return
+// LLVM: define dso_local void @label_in_if
+// LLVM: br label %3
+// LLVM: 3:
+// LLVM: [[LOAD:%.*]] = load i32, ptr [[COND:%.*]], align 4
+// LLVM: [[CMP:%.*]] = icmp ne i32 [[LOAD]], 0
+// LLVM: br i1 [[CMP]], label %6, label %9
+// LLVM: 6:
+// LLVM: [[LOAD2:%.*]] = load i32, ptr [[COND]], align 4
+// LLVM: [[ADD1:%.*]] = add nsw i32 [[LOAD2]], 1
+// LLVM: store i32 [[ADD1]], ptr [[COND]], align 4
+// LLVM: br label %9
+// LLVM: 9:
+// LLVM: br label %10
+// LLVM: 10:
+// LLVM: ret void
+
// OGCG: define dso_local void @label_in_if
// OGCG: if.then:
// OGCG: br label %labelD
@@ -80,6 +103,13 @@ void after_return() {
// CIR: cir.label "label"
// CIR: cir.br ^bb1
+// LLVM: define dso_local void @after_return
+// LLVM: br label %1
+// LLVM: 1:
+// LLVM: ret void
+// LLVM: 2:
+// LLVM: br label %1
+
// OGCG: define dso_local void @after_return
// OGCG: br label %label
// OGCG: label:
@@ -97,6 +127,11 @@ void after_unreachable() {
// CIR: cir.label "label"
// CIR: cir.return
+// LLVM: define dso_local void @after_unreachable
+// LLVM: unreachable
+// LLVM: 1:
+// LLVM: ret void
+
// OGCG: define dso_local void @after_unreachable
// OGCG: unreachable
// OGCG: label:
@@ -111,6 +146,9 @@ void labelWithoutMatch() {
// CIR: cir.return
// CIR: }
+// LLVM: define dso_local void @labelWithoutMatch
+// LLVM: ret void
+
// OGCG: define dso_local void @labelWithoutMatch
// OGCG: br label %end
// OGCG: end:
@@ -132,6 +170,15 @@ void foo() {
// CIR: cir.label "label"
// CIR: %0 = cir.alloca !rec_S, !cir.ptr<!rec_S>, ["agg.tmp0"]
+// LLVM:define dso_local void @foo() {
+// LLVM: [[ALLOC:%.*]] = alloca %struct.S, i64 1, align 1
+// LLVM: br label %2
+// LLVM:2:
+// LLVM: [[CALL:%.*]] = call %struct.S @get()
+// LLVM: store %struct.S [[CALL]], ptr [[ALLOC]], align 1
+// LLVM: [[LOAD:%.*]] = load %struct.S, ptr [[ALLOC]], align 1
+// LLVM: call void @bar(%struct.S [[LOAD]])
+
// OGCG: define dso_local void @foo()
// OGCG: %agg.tmp = alloca %struct.S, align 1
// OGCG: %undef.agg.tmp = alloca %struct.S, align 1
diff --git a/clang/test/CIR/Lowering/goto.cir b/clang/test/CIR/Lowering/goto.cir
new file mode 100644
index 0000000000000..cd3a57d2e7138
--- /dev/null
+++ b/clang/test/CIR/Lowering/goto.cir
@@ -0,0 +1,52 @@
+// RUN: cir-opt %s --pass-pipeline='builtin.module(cir-to-llvm,canonicalize{region-simplify=disabled})' -o - | FileCheck %s -check-prefix=MLIR
+
+!s32i = !cir.int<s, 32>
+
+module {
+
+ cir.func @gotoFromIf(%arg0: !s32i) -> !s32i {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init] {alignment = 4 : i64}
+ %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %6 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ %7 = cir.const #cir.int<5> : !s32i
+ %8 = cir.cmp(gt, %6, %7) : !s32i, !cir.bool
+ cir.if %8 {
+ cir.goto "err"
+ }
+ }
+ %2 = cir.const #cir.int<0> : !s32i
+ cir.store %2, %1 : !s32i, !cir.ptr<!s32i>
+ cir.br ^bb1
+ ^bb1:
+ %3 = cir.load %1 : !cir.ptr<!s32i>, !s32i
+ cir.return %3 : !s32i
+ ^bb2:
+ cir.label "err"
+ %4 = cir.const #cir.int<1> : !s32i
+ %5 = cir.unary(minus, %4) : !s32i, !s32i
+ cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
+ cir.br ^bb1
+ }
+
+// MLIR: llvm.func @gotoFromIf
+// MLIR: %[[#One:]] = llvm.mlir.constant(1 : i32) : i32
+// MLIR: %[[#Zero:]] = llvm.mlir.constant(0 : i32) : i32
+// MLIR: llvm.cond_br {{.*}}, ^bb[[#COND_YES:]], ^bb[[#COND_NO:]]
+// MLIR: ^bb[[#COND_YES]]:
+// MLIR: llvm.br ^bb[[#GOTO_BLK:]]
+// MLIR: ^bb[[#COND_NO]]:
+// MLIR: llvm.br ^bb[[#BLK:]]
+// MLIR: ^bb[[#BLK]]:
+// MLIR: llvm.store %[[#Zero]], %[[#Ret_val_addr:]] {{.*}}: i32, !llvm.ptr
+// MLIR: llvm.br ^bb[[#RETURN:]]
+// MLIR: ^bb[[#RETURN]]:
+// MLIR: %[[#Ret_val:]] = llvm.load %[[#Ret_val_addr]] {alignment = 4 : i64} : !llvm.ptr -> i32
+// MLIR: llvm.return %[[#Ret_val]] : i32
+// MLIR: ^bb[[#GOTO_BLK]]:
+// MLIR: %[[#Neg_one:]] = llvm.sub %[[#Zero]], %[[#One]] : i32
+// MLIR: llvm.store %[[#Neg_one]], %[[#Ret_val_addr]] {{.*}}: i32, !llvm.ptr
+// MLIR: llvm.br ^bb[[#RETURN]]
+// MLIR: }
+}
|
andykaylor
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great. I have just a few nits.
| namespace { | ||
|
|
||
| struct GotoSolverPass : public GotoSolverBase<GotoSolverPass> { | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Can you remove this blank line?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
| }; | ||
|
|
||
| static void process(cir::FuncOp func) { | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove blank line? I know this is probably a personal preference, but I'd like to see us be consistent about not having blank lines at the start of a scope.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
| mlir::OpBuilder::InsertionGuard guard(rewriter); | ||
| rewriter.setInsertionPoint(goTo); | ||
| Block *dest = labels[goTo.getLabel()]; | ||
| rewriter.create<cir::BrOp>(goTo.getLoc(), dest); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| rewriter.create<cir::BrOp>(goTo.getLoc(), dest); | |
| cir::BrOp::create(rewriter, goTo.getLoc(), dest); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
bcardosolopes
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM after nit
| @@ -0,0 +1,52 @@ | |||
| #include "PassDetail.h" | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You need the LLVM comment header boilerplate here!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for catching that.
mmha
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just one nit
|
|
||
| void GotoSolverPass::runOnOperation() { | ||
| llvm::TimeTraceScope scope("Goto Solver"); | ||
| getOperation()->walk([&](cir::FuncOp op) { process(op); }); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can probably pass process directly instead of wrapping it in a lambda
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/195/builds/13560 Here is the relevant piece of the build log for the reference |
This PR upstreams the GotoSolver pass.
It works by walking the function and matching each label to a goto. If a label is not matched to a goto, it is removed and not lowered.