Skip to content

Commit 5ac8751

Browse files
committed
[CIR] Add runAtStartOfConvertCIRToLLVMPass() to ConvertCIRToLLVMPass
This allows injecting some dialects and patterns to the cir::direct::ConvertCIRToLLVMPass lowering pass.
1 parent 6ae8bfa commit 5ac8751

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

clang/include/clang/CIR/LowerToLLVM.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
#define CLANG_CIR_LOWERTOLLVM_H
1414

1515
#include "mlir/Pass/Pass.h"
16+
#include "mlir/Transforms/DialectConversion.h"
1617

18+
#include <functional>
1719
#include <memory>
1820

1921
namespace llvm {
@@ -33,7 +35,9 @@ std::unique_ptr<llvm::Module> lowerDirectlyFromCIRToLLVMIR(
3335
mlir::ModuleOp theModule, llvm::LLVMContext &llvmCtx,
3436
bool disableVerifier = false, bool disableCCLowering = false,
3537
bool disableDebugInfo = false);
36-
}
38+
void runAtStartOfConvertCIRToLLVMPass(
39+
std::function<void(mlir::ConversionTarget)>);
40+
} // namespace direct
3741

3842
// Lower directly from pristine CIR to LLVMIR.
3943
std::unique_ptr<llvm::Module>

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1352,6 +1352,8 @@ struct ConvertCIRToLLVMPass
13521352

13531353
void processCIRAttrs(mlir::ModuleOp moduleOp);
13541354

1355+
inline static std::function<void(mlir::ConversionTarget)> runAtStartHook;
1356+
13551357
StringRef getDescription() const override {
13561358
return "Convert the prepared CIR dialect module to LLVM dialect";
13571359
}
@@ -4692,6 +4694,9 @@ void ConvertCIRToLLVMPass::runOnOperation() {
46924694
// ,YieldOp
46934695
>();
46944696
// clang-format on
4697+
if (runAtStartHook)
4698+
runAtStartHook(target);
4699+
46954700
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
46964701
target.addIllegalDialect<mlir::BuiltinDialect, cir::CIRDialect,
46974702
mlir::func::FuncDialect>();
@@ -4729,6 +4734,13 @@ void ConvertCIRToLLVMPass::runOnOperation() {
47294734
buildGlobalAnnotationsVar(stringGlobalsMap, argStringGlobalsMap, argsVarMap);
47304735
}
47314736

4737+
/// Set a hook to be called just before applying the dialect conversion so other
4738+
/// dialects or patterns can be added
4739+
void runAtStartOfConvertCIRToLLVMPass(
4740+
std::function<void(mlir::ConversionTarget)> hook) {
4741+
ConvertCIRToLLVMPass::runAtStartHook = std::move(hook);
4742+
}
4743+
47324744
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
47334745
return std::make_unique<ConvertCIRToLLVMPass>();
47344746
}

0 commit comments

Comments
 (0)