Skip to content

Commit e776e31

Browse files
committed
[CIR] Add runAtStartOfConvertCIRToLLVMPass() to ConvertCIRToLLVMPass
This allows injecting some dialects and patterns to the cir::direct::ConvertCIRToLLVMPass lowering pass.
1 parent 6a1848c commit e776e31

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

clang/include/clang/CIR/LowerToLLVM.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
#define CLANG_CIR_LOWERTOLLVM_H
1414

1515
#include "mlir/Pass/Pass.h"
16+
#include "mlir/Transforms/DialectConversion.h"
1617

18+
#include <functional>
1719
#include <memory>
1820

1921
namespace llvm {
@@ -33,7 +35,9 @@ std::unique_ptr<llvm::Module> lowerDirectlyFromCIRToLLVMIR(
3335
mlir::ModuleOp theModule, llvm::LLVMContext &llvmCtx,
3436
bool disableVerifier = false, bool disableCCLowering = false,
3537
bool disableDebugInfo = false);
36-
}
38+
void runAtStartOfConvertCIRToLLVMPass(
39+
std::function<void(mlir::ConversionTarget)>);
40+
} // namespace direct
3741

3842
// Lower directly from pristine CIR to LLVMIR.
3943
std::unique_ptr<llvm::Module>

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,8 @@ struct ConvertCIRToLLVMPass
13841384

13851385
void processCIRAttrs(mlir::ModuleOp moduleOp);
13861386

1387+
inline static std::function<void(mlir::ConversionTarget)> runAtStartHook;
1388+
13871389
StringRef getDescription() const override {
13881390
return "Convert the prepared CIR dialect module to LLVM dialect";
13891391
}
@@ -4640,6 +4642,9 @@ void ConvertCIRToLLVMPass::runOnOperation() {
46404642
// ,YieldOp
46414643
>();
46424644
// clang-format on
4645+
if (runAtStartHook)
4646+
runAtStartHook(target);
4647+
46434648
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
46444649
target.addIllegalDialect<mlir::BuiltinDialect, cir::CIRDialect,
46454650
mlir::func::FuncDialect>();
@@ -4677,6 +4682,13 @@ void ConvertCIRToLLVMPass::runOnOperation() {
46774682
buildGlobalAnnotationsVar(stringGlobalsMap, argStringGlobalsMap, argsVarMap);
46784683
}
46794684

4685+
/// Set a hook to be called just before applying the dialect conversion so other
4686+
/// dialects or patterns can be added
4687+
void runAtStartOfConvertCIRToLLVMPass(
4688+
std::function<void(mlir::ConversionTarget)> hook) {
4689+
ConvertCIRToLLVMPass::runAtStartHook = std::move(hook);
4690+
}
4691+
46804692
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
46814693
return std::make_unique<ConvertCIRToLLVMPass>();
46824694
}

0 commit comments

Comments
 (0)