|
22 | 22 | #include "mlir/Dialect/SCF/IR/SCF.h"
|
23 | 23 | #include "mlir/IR/SymbolTable.h"
|
24 | 24 | #include "mlir/Pass/Pass.h"
|
25 |
| -#include "mlir/Transforms/DialectConversion.h" |
| 25 | +#include "mlir/Transforms/WalkPatternRewriteDriver.h" |
26 | 26 |
|
27 | 27 | namespace mlir {
|
28 | 28 | #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
|
@@ -538,15 +538,18 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
|
538 | 538 |
|
539 | 539 | /// Applies the conversion patterns in the given function.
|
540 | 540 | static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
|
541 |
| - ConversionTarget target(*module.getContext()); |
542 |
| - target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(); |
543 |
| - target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect, |
544 |
| - memref::MemRefDialect>(); |
545 |
| - |
546 | 541 | RewritePatternSet patterns(module.getContext());
|
547 | 542 | patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
|
548 | 543 | FrozenRewritePatternSet frozen(std::move(patterns));
|
549 |
| - return applyPartialConversion(module, target, frozen); |
| 544 | + walkAndApplyPatterns(module, frozen); |
| 545 | + auto status = module.walk([](Operation *op) { |
| 546 | + if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op)) { |
| 547 | + op->emitError("unconverted operation found"); |
| 548 | + return WalkResult::interrupt(); |
| 549 | + } |
| 550 | + return WalkResult::advance(); |
| 551 | + }); |
| 552 | + return failure(status.wasInterrupted()); |
550 | 553 | }
|
551 | 554 |
|
552 | 555 | /// A pass converting SCF operations to OpenMP operations.
|
|
0 commit comments