Skip to content

Commit 6ae8bfa

Browse files
committed
[CIR] Add runAtStartOfConvertCIRToMLIRPass() to ConvertCIRToMLIRPass
This allows injecting some dialects and patterns to the cir::ConvertCIRToMLIRPass lowering pass.
1 parent fa2c8f7 commit 6ae8bfa

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

clang/include/clang/CIR/LowerToMLIR.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,15 @@
1212
#ifndef CLANG_CIR_LOWERTOMLIR_H
1313
#define CLANG_CIR_LOWERTOMLIR_H
1414

15+
#include "mlir/Transforms/DialectConversion.h"
16+
#include <functional>
17+
1518
namespace cir {
1619

1720
void populateCIRLoopToSCFConversionPatterns(mlir::RewritePatternSet &patterns,
1821
mlir::TypeConverter &converter);
1922
mlir::TypeConverter prepareTypeConverter();
23+
void runAtStartOfConvertCIRToMLIRPass(std::function<void(mlir::ConversionTarget)>);
2024
} // namespace cir
2125

2226
#endif // CLANG_CIR_LOWERTOMLIR_H_

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include <functional>
14+
1315
#include "LowerToMLIRHelpers.h"
1416
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
1517
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
@@ -91,6 +93,8 @@ struct ConvertCIRToMLIRPass
9193
}
9294

9395
StringRef getArgument() const override { return "cir-to-mlir"; }
96+
97+
inline static std::function<void(mlir::ConversionTarget)> runAtStartHook;
9498
};
9599

96100
class CIRCallOpLowering : public mlir::OpConversionPattern<cir::CallOp> {
@@ -1461,10 +1465,20 @@ void ConvertCIRToMLIRPass::runOnOperation() {
14611465
mlir::math::MathDialect, mlir::vector::VectorDialect>();
14621466
target.addIllegalDialect<cir::CIRDialect>();
14631467

1468+
if (runAtStartHook)
1469+
runAtStartHook(target);
1470+
14641471
if (failed(applyPartialConversion(module, target, std::move(patterns))))
14651472
signalPassFailure();
14661473
}
14671474

1475+
/// Set a hook to be called just before applying the dialect conversion so other
1476+
/// dialects or patterns can be added
1477+
void runAtStartOfConvertCIRToMLIRPass(
1478+
std::function<void(mlir::ConversionTarget)> hook) {
1479+
ConvertCIRToMLIRPass::runAtStartHook = std::move(hook);
1480+
}
1481+
14681482
std::unique_ptr<llvm::Module>
14691483
lowerFromCIRToMLIRToLLVMIR(mlir::ModuleOp theModule,
14701484
std::unique_ptr<mlir::MLIRContext> mlirCtx,

0 commit comments

Comments
 (0)