|
10 | 10 | //
|
11 | 11 | //===----------------------------------------------------------------------===//
|
12 | 12 |
|
| 13 | +#include <functional> |
| 14 | + |
13 | 15 | #include "LowerToMLIRHelpers.h"
|
14 | 16 | #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
15 | 17 | #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
|
@@ -91,6 +93,8 @@ struct ConvertCIRToMLIRPass
|
91 | 93 | }
|
92 | 94 |
|
93 | 95 | StringRef getArgument() const override { return "cir-to-mlir"; }
|
| 96 | + |
| 97 | + inline static std::function<void(mlir::ConversionTarget)> runAtStartHook; |
94 | 98 | };
|
95 | 99 |
|
96 | 100 | class CIRCallOpLowering : public mlir::OpConversionPattern<cir::CallOp> {
|
@@ -1454,10 +1458,20 @@ void ConvertCIRToMLIRPass::runOnOperation() {
|
1454 | 1458 | mlir::math::MathDialect, mlir::vector::VectorDialect>();
|
1455 | 1459 | target.addIllegalDialect<cir::CIRDialect>();
|
1456 | 1460 |
|
| 1461 | + if (runAtStartHook) |
| 1462 | + runAtStartHook(target); |
| 1463 | + |
1457 | 1464 | if (failed(applyPartialConversion(module, target, std::move(patterns))))
|
1458 | 1465 | signalPassFailure();
|
1459 | 1466 | }
|
1460 | 1467 |
|
| 1468 | +/// Set a hook to be called just before applying the dialect conversion so other |
| 1469 | +/// dialects or patterns can be added |
| 1470 | +void runAtStartOfConvertCIRToMLIRPass( |
| 1471 | + std::function<void(mlir::ConversionTarget)> hook) { |
| 1472 | + ConvertCIRToMLIRPass::runAtStartHook = std::move(hook); |
| 1473 | +} |
| 1474 | + |
1461 | 1475 | std::unique_ptr<llvm::Module>
|
1462 | 1476 | lowerFromCIRToMLIRToLLVMIR(mlir::ModuleOp theModule,
|
1463 | 1477 | std::unique_ptr<mlir::MLIRContext> mlirCtx,
|
|
0 commit comments