Skip to content

Commit e896915

Browse files
committed
[CIR] Add runAtStartOfConvertCIRToLLVMPass() to ConvertCIRToLLVMPass
This allows injecting some dialects and patterns to the cir::direct::ConvertCIRToLLVMPass lowering pass.
1 parent c2e0d81 commit e896915

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

clang/include/clang/CIR/LowerToLLVM.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
#define CLANG_CIR_LOWERTOLLVM_H
1414

1515
#include "mlir/Pass/Pass.h"
16+
#include "mlir/Transforms/DialectConversion.h"
1617

18+
#include <functional>
1719
#include <memory>
1820

1921
namespace llvm {
@@ -33,7 +35,9 @@ std::unique_ptr<llvm::Module> lowerDirectlyFromCIRToLLVMIR(
3335
mlir::ModuleOp theModule, llvm::LLVMContext &llvmCtx,
3436
bool disableVerifier = false, bool disableCCLowering = false,
3537
bool disableDebugInfo = false);
36-
}
38+
void runAtStartOfConvertCIRToLLVMPass(
39+
std::function<void(mlir::ConversionTarget)>);
40+
} // namespace direct
3741

3842
// Lower directly from pristine CIR to LLVMIR.
3943
std::unique_ptr<llvm::Module>

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1347,6 +1347,8 @@ struct ConvertCIRToLLVMPass
13471347

13481348
void processCIRAttrs(mlir::ModuleOp moduleOp);
13491349

1350+
inline static std::function<void(mlir::ConversionTarget)> runAtStartHook;
1351+
13501352
StringRef getDescription() const override {
13511353
return "Convert the prepared CIR dialect module to LLVM dialect";
13521354
}
@@ -4626,6 +4628,9 @@ void ConvertCIRToLLVMPass::runOnOperation() {
46264628
// ,YieldOp
46274629
>();
46284630
// clang-format on
4631+
if (runAtStartHook)
4632+
runAtStartHook(target);
4633+
46294634
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
46304635
target.addIllegalDialect<mlir::BuiltinDialect, cir::CIRDialect,
46314636
mlir::func::FuncDialect>();
@@ -4663,6 +4668,13 @@ void ConvertCIRToLLVMPass::runOnOperation() {
46634668
buildGlobalAnnotationsVar(stringGlobalsMap, argStringGlobalsMap, argsVarMap);
46644669
}
46654670

4671+
/// Set a hook to be called just before applying the dialect conversion so other
4672+
/// dialects or patterns can be added
4673+
void runAtStartOfConvertCIRToLLVMPass(
4674+
std::function<void(mlir::ConversionTarget)> hook) {
4675+
ConvertCIRToLLVMPass::runAtStartHook = std::move(hook);
4676+
}
4677+
46664678
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
46674679
return std::make_unique<ConvertCIRToLLVMPass>();
46684680
}

0 commit comments

Comments
 (0)