Skip to content

Commit 0d3fa2d

Browse files
committed
[CIR] Add runAtStartOfConvertCIRToLLVMPass() to ConvertCIRToLLVMPass
This allows injecting some dialects and patterns to the cir::direct::ConvertCIRToLLVMPass lowering pass.
1 parent 229cd13 commit 0d3fa2d

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

clang/include/clang/CIR/LowerToLLVM.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
#define CLANG_CIR_LOWERTOLLVM_H
1414

1515
#include "mlir/Pass/Pass.h"
16+
#include "mlir/Transforms/DialectConversion.h"
1617

18+
#include <functional>
1719
#include <memory>
1820

1921
namespace llvm {
@@ -34,6 +36,9 @@ mlir::ModuleOp lowerDirectlyFromCIRToLLVMDialect(mlir::ModuleOp theModule,
3436
bool disableCCLowering = false,
3537
bool disableDebugInfo = false);
3638

39+
void runAtStartOfConvertCIRToLLVMPass(
40+
std::function<void(mlir::ConversionTarget)>);
41+
3742
// Lower directly from pristine CIR to LLVMIR.
3843
std::unique_ptr<llvm::Module> lowerDirectlyFromCIRToLLVMIR(
3944
mlir::ModuleOp theModule, llvm::LLVMContext &llvmCtx,

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,6 +1381,8 @@ struct ConvertCIRToLLVMPass
13811381

13821382
void processCIRAttrs(mlir::ModuleOp moduleOp);
13831383

1384+
inline static std::function<void(mlir::ConversionTarget)> runAtStartHook;
1385+
13841386
StringRef getDescription() const override {
13851387
return "Convert the prepared CIR dialect module to LLVM dialect";
13861388
}
@@ -4741,6 +4743,9 @@ void ConvertCIRToLLVMPass::runOnOperation() {
47414743
// ,YieldOp
47424744
>();
47434745
// clang-format on
4746+
if (runAtStartHook)
4747+
runAtStartHook(target);
4748+
47444749
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
47454750
target.addIllegalDialect<mlir::BuiltinDialect, cir::CIRDialect,
47464751
mlir::func::FuncDialect>();
@@ -4778,6 +4783,13 @@ void ConvertCIRToLLVMPass::runOnOperation() {
47784783
processCIRAttrs(module);
47794784
}
47804785

4786+
/// Set a hook to be called just before applying the dialect conversion so other
4787+
/// dialects or patterns can be added
4788+
void runAtStartOfConvertCIRToLLVMPass(
4789+
std::function<void(mlir::ConversionTarget)> hook) {
4790+
ConvertCIRToLLVMPass::runAtStartHook = std::move(hook);
4791+
}
4792+
47814793
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
47824794
return std::make_unique<ConvertCIRToLLVMPass>();
47834795
}

0 commit comments

Comments
 (0)