|
6 | 6 | // |
7 | 7 | //===----------------------------------------------------------------------===// |
8 | 8 |
|
| 9 | +#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h" |
9 | 10 | #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" |
10 | 11 | #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" |
11 | 12 | #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h" |
|
23 | 24 | #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" |
24 | 25 | #include "mlir/IR/PatternMatch.h" |
25 | 26 | #include "mlir/Pass/Pass.h" |
26 | | -#include "mlir/Pass/PassOptions.h" |
27 | 27 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
28 | 28 | #include "mlir/Transforms/DialectConversion.h" |
29 | 29 | #include <memory> |
30 | 30 |
|
31 | | -#define DEBUG_TYPE "test-convert-to-spirv" |
| 31 | +#define DEBUG_TYPE "convert-to-spirv" |
| 32 | + |
| 33 | +namespace mlir { |
| 34 | +#define GEN_PASS_DEF_CONVERTTOSPIRVPASS |
| 35 | +#include "mlir/Conversion/Passes.h.inc" |
| 36 | +} // namespace mlir |
32 | 37 |
|
33 | 38 | using namespace mlir; |
34 | 39 |
|
@@ -64,44 +69,9 @@ void populateConvertToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, |
64 | 69 | } |
65 | 70 |
|
66 | 71 | /// A pass to perform the SPIR-V conversion. |
67 | | -struct TestConvertToSPIRVPass final |
68 | | - : PassWrapper<TestConvertToSPIRVPass, OperationPass<>> { |
69 | | - Option<bool> runSignatureConversion{ |
70 | | - *this, "run-signature-conversion", |
71 | | - llvm::cl::desc( |
72 | | - "Run function signature conversion to convert vector types"), |
73 | | - llvm::cl::init(true)}; |
74 | | - Option<bool> runVectorUnrolling{ |
75 | | - *this, "run-vector-unrolling", |
76 | | - llvm::cl::desc( |
77 | | - "Run vector unrolling to convert vector types in function bodies"), |
78 | | - llvm::cl::init(true)}; |
79 | | - Option<bool> convertGPUModules{ |
80 | | - *this, "convert-gpu-modules", |
81 | | - llvm::cl::desc("Clone and convert GPU modules"), llvm::cl::init(false)}; |
82 | | - Option<bool> nestInGPUModule{ |
83 | | - *this, "nest-in-gpu-module", |
84 | | - llvm::cl::desc("Put converted SPIR-V module inside the gpu.module " |
85 | | - "instead of alongside it."), |
86 | | - llvm::cl::init(false)}; |
87 | | - |
88 | | - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertToSPIRVPass) |
89 | | - |
90 | | - StringRef getArgument() const final { return "test-convert-to-spirv"; } |
91 | | - StringRef getDescription() const final { |
92 | | - return "Conversion to SPIR-V pass only used for internal tests."; |
93 | | - } |
94 | | - void getDependentDialects(DialectRegistry ®istry) const override { |
95 | | - registry.insert<spirv::SPIRVDialect>(); |
96 | | - registry.insert<vector::VectorDialect>(); |
97 | | - } |
98 | | - |
99 | | - TestConvertToSPIRVPass() = default; |
100 | | - TestConvertToSPIRVPass(bool convertGPUModules, bool nestInGPUModule) { |
101 | | - this->convertGPUModules = convertGPUModules; |
102 | | - this->nestInGPUModule = nestInGPUModule; |
103 | | - }; |
104 | | - TestConvertToSPIRVPass(const TestConvertToSPIRVPass &) {} |
| 72 | +struct ConvertToSPIRVPass final |
| 73 | + : impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> { |
| 74 | + using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase; |
105 | 75 |
|
106 | 76 | void runOnOperation() override { |
107 | 77 | Operation *op = getOperation(); |
@@ -162,14 +132,3 @@ struct TestConvertToSPIRVPass final |
162 | 132 | }; |
163 | 133 |
|
164 | 134 | } // namespace |
165 | | - |
166 | | -namespace mlir::test { |
167 | | -void registerTestConvertToSPIRVPass() { |
168 | | - PassRegistration<TestConvertToSPIRVPass>(); |
169 | | -} |
170 | | -std::unique_ptr<Pass> createTestConvertToSPIRVPass(bool convertGPUModules, |
171 | | - bool nestInGPUModule) { |
172 | | - return std::make_unique<TestConvertToSPIRVPass>(convertGPUModules, |
173 | | - nestInGPUModule); |
174 | | -} |
175 | | -} // namespace mlir::test |
0 commit comments