|
22 | 22 | #include "mlir/Dialect/SCF/IR/SCF.h" |
23 | 23 | #include "mlir/IR/SymbolTable.h" |
24 | 24 | #include "mlir/Pass/Pass.h" |
25 | | -#include "mlir/Transforms/DialectConversion.h" |
| 25 | +#include "mlir/Transforms/WalkPatternRewriteDriver.h" |
26 | 26 |
|
27 | 27 | namespace mlir { |
28 | 28 | #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS |
@@ -538,15 +538,18 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { |
538 | 538 |
|
539 | 539 | /// Applies the conversion patterns in the given function. |
540 | 540 | static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) { |
541 | | - ConversionTarget target(*module.getContext()); |
542 | | - target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(); |
543 | | - target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect, |
544 | | - memref::MemRefDialect>(); |
545 | | - |
546 | 541 | RewritePatternSet patterns(module.getContext()); |
547 | 542 | patterns.add<ParallelOpLowering>(module.getContext(), numThreads); |
548 | 543 | FrozenRewritePatternSet frozen(std::move(patterns)); |
549 | | - return applyPartialConversion(module, target, frozen); |
| 544 | + walkAndApplyPatterns(module, frozen); |
| 545 | + auto status = module.walk([](Operation *op) { |
| 546 | + if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op)) { |
| 547 | + op->emitError("unconverted operation found"); |
| 548 | + return WalkResult::interrupt(); |
| 549 | + } |
| 550 | + return WalkResult::advance(); |
| 551 | + }); |
| 552 | + return failure(status.wasInterrupted()); |
550 | 553 | } |
551 | 554 |
|
552 | 555 | /// A pass converting SCF operations to OpenMP operations. |
|
0 commit comments