Skip to content

Commit 229cd13

Browse files
committed
[CIR] Add runAtStartOfConvertCIRToMLIRPass() to ConvertCIRToMLIRPass
This allows injecting some dialects and patterns to the cir::ConvertCIRToMLIRPass lowering pass.
1 parent 754364d commit 229cd13

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

clang/include/clang/CIR/LowerToMLIR.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#define CLANG_CIR_LOWERTOMLIR_H
1414

1515
#include "mlir/Transforms/DialectConversion.h"
16+
#include <functional>
1617

1718
namespace cir {
1819

@@ -21,6 +22,8 @@ void populateCIRLoopToSCFConversionPatterns(mlir::RewritePatternSet &patterns,
2122

2223
mlir::TypeConverter prepareTypeConverter();
2324

25+
void runAtStartOfConvertCIRToMLIRPass(std::function<void(mlir::ConversionTarget)>);
26+
2427
mlir::ModuleOp
2528
lowerFromCIRToMLIRToLLVMDialect(mlir::ModuleOp theModule,
2629
mlir::MLIRContext *mlirCtx = nullptr);

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include <functional>
14+
1315
#include "LowerToMLIRHelpers.h"
1416
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
1517
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
@@ -92,6 +94,8 @@ struct ConvertCIRToMLIRPass
9294
}
9395

9496
StringRef getArgument() const override { return "cir-to-mlir"; }
97+
98+
inline static std::function<void(mlir::ConversionTarget)> runAtStartHook;
9599
};
96100

97101
class CIRCallOpLowering : public mlir::OpConversionPattern<cir::CallOp> {
@@ -1460,10 +1464,20 @@ void ConvertCIRToMLIRPass::runOnOperation() {
14601464
mlir::math::MathDialect, mlir::vector::VectorDialect>();
14611465
target.addIllegalDialect<cir::CIRDialect>();
14621466

1467+
if (runAtStartHook)
1468+
runAtStartHook(target);
1469+
14631470
if (failed(applyPartialConversion(module, target, std::move(patterns))))
14641471
signalPassFailure();
14651472
}
14661473

1474+
/// Set a hook to be called just before applying the dialect conversion so other
1475+
/// dialects or patterns can be added
1476+
void runAtStartOfConvertCIRToMLIRPass(
1477+
std::function<void(mlir::ConversionTarget)> hook) {
1478+
ConvertCIRToMLIRPass::runAtStartHook = std::move(hook);
1479+
}
1480+
14671481
mlir::ModuleOp lowerFromCIRToMLIRToLLVMDialect(mlir::ModuleOp theModule,
14681482
mlir::MLIRContext *mlirCtx) {
14691483
llvm::TimeTraceScope scope("Lower from CIR to MLIR To LLVM Dialect");

0 commit comments

Comments
 (0)