|
10 | 10 | //
|
11 | 11 | //===----------------------------------------------------------------------===//
|
12 | 12 |
|
| 13 | +#include <functional> |
| 14 | + |
13 | 15 | #include "LowerToMLIRHelpers.h"
|
14 | 16 | #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
15 | 17 | #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
|
@@ -91,6 +93,8 @@ struct ConvertCIRToMLIRPass
|
91 | 93 | }
|
92 | 94 |
|
93 | 95 | StringRef getArgument() const override { return "cir-to-mlir"; }
|
| 96 | + |
| 97 | + inline static std::function<void(mlir::ConversionTarget)> runAtStartHook; |
94 | 98 | };
|
95 | 99 |
|
96 | 100 | class CIRCallOpLowering : public mlir::OpConversionPattern<cir::CallOp> {
|
@@ -1461,10 +1465,20 @@ void ConvertCIRToMLIRPass::runOnOperation() {
|
1461 | 1465 | mlir::math::MathDialect, mlir::vector::VectorDialect>();
|
1462 | 1466 | target.addIllegalDialect<cir::CIRDialect>();
|
1463 | 1467 |
|
| 1468 | + if (runAtStartHook) |
| 1469 | + runAtStartHook(target); |
| 1470 | + |
1464 | 1471 | if (failed(applyPartialConversion(module, target, std::move(patterns))))
|
1465 | 1472 | signalPassFailure();
|
1466 | 1473 | }
|
1467 | 1474 |
|
| 1475 | +/// Set a hook to be called just before applying the dialect conversion so other |
| 1476 | +/// dialects or patterns can be added |
| 1477 | +void runAtStartOfConvertCIRToMLIRPass( |
| 1478 | + std::function<void(mlir::ConversionTarget)> hook) { |
| 1479 | + ConvertCIRToMLIRPass::runAtStartHook = std::move(hook); |
| 1480 | +} |
| 1481 | + |
1468 | 1482 | std::unique_ptr<llvm::Module>
|
1469 | 1483 | lowerFromCIRToMLIRToLLVMIR(mlir::ModuleOp theModule,
|
1470 | 1484 | std::unique_ptr<mlir::MLIRContext> mlirCtx,
|
|
0 commit comments