|
6 | 6 | //
|
7 | 7 | //===----------------------------------------------------------------------===//
|
8 | 8 |
|
| 9 | +#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h" |
9 | 10 | #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
|
10 | 11 | #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
|
11 | 12 | #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
|
|
23 | 24 | #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
|
24 | 25 | #include "mlir/IR/PatternMatch.h"
|
25 | 26 | #include "mlir/Pass/Pass.h"
|
26 |
| -#include "mlir/Pass/PassOptions.h" |
27 | 27 | #include "mlir/Rewrite/FrozenRewritePatternSet.h"
|
28 | 28 | #include "mlir/Transforms/DialectConversion.h"
|
29 | 29 | #include <memory>
|
30 | 30 |
|
31 |
| -#define DEBUG_TYPE "test-convert-to-spirv" |
| 31 | +#define DEBUG_TYPE "convert-to-spirv" |
| 32 | + |
| 33 | +namespace mlir { |
| 34 | +#define GEN_PASS_DEF_CONVERTTOSPIRVPASS |
| 35 | +#include "mlir/Conversion/Passes.h.inc" |
| 36 | +} // namespace mlir |
32 | 37 |
|
33 | 38 | using namespace mlir;
|
34 | 39 |
|
@@ -64,44 +69,9 @@ void populateConvertToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
|
64 | 69 | }
|
65 | 70 |
|
66 | 71 | /// A pass to perform the SPIR-V conversion.
|
67 |
| -struct TestConvertToSPIRVPass final |
68 |
| - : PassWrapper<TestConvertToSPIRVPass, OperationPass<>> { |
69 |
| - Option<bool> runSignatureConversion{ |
70 |
| - *this, "run-signature-conversion", |
71 |
| - llvm::cl::desc( |
72 |
| - "Run function signature conversion to convert vector types"), |
73 |
| - llvm::cl::init(true)}; |
74 |
| - Option<bool> runVectorUnrolling{ |
75 |
| - *this, "run-vector-unrolling", |
76 |
| - llvm::cl::desc( |
77 |
| - "Run vector unrolling to convert vector types in function bodies"), |
78 |
| - llvm::cl::init(true)}; |
79 |
| - Option<bool> convertGPUModules{ |
80 |
| - *this, "convert-gpu-modules", |
81 |
| - llvm::cl::desc("Clone and convert GPU modules"), llvm::cl::init(false)}; |
82 |
| - Option<bool> nestInGPUModule{ |
83 |
| - *this, "nest-in-gpu-module", |
84 |
| - llvm::cl::desc("Put converted SPIR-V module inside the gpu.module " |
85 |
| - "instead of alongside it."), |
86 |
| - llvm::cl::init(false)}; |
87 |
| - |
88 |
| - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertToSPIRVPass) |
89 |
| - |
90 |
| - StringRef getArgument() const final { return "test-convert-to-spirv"; } |
91 |
| - StringRef getDescription() const final { |
92 |
| - return "Conversion to SPIR-V pass only used for internal tests."; |
93 |
| - } |
94 |
| - void getDependentDialects(DialectRegistry ®istry) const override { |
95 |
| - registry.insert<spirv::SPIRVDialect>(); |
96 |
| - registry.insert<vector::VectorDialect>(); |
97 |
| - } |
98 |
| - |
99 |
| - TestConvertToSPIRVPass() = default; |
100 |
| - TestConvertToSPIRVPass(bool convertGPUModules, bool nestInGPUModule) { |
101 |
| - this->convertGPUModules = convertGPUModules; |
102 |
| - this->nestInGPUModule = nestInGPUModule; |
103 |
| - }; |
104 |
| - TestConvertToSPIRVPass(const TestConvertToSPIRVPass &) {} |
| 72 | +struct ConvertToSPIRVPass final |
| 73 | + : impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> { |
| 74 | + using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase; |
105 | 75 |
|
106 | 76 | void runOnOperation() override {
|
107 | 77 | Operation *op = getOperation();
|
@@ -162,14 +132,3 @@ struct TestConvertToSPIRVPass final
|
162 | 132 | };
|
163 | 133 |
|
164 | 134 | } // namespace
|
165 |
| - |
166 |
| -namespace mlir::test { |
167 |
| -void registerTestConvertToSPIRVPass() { |
168 |
| - PassRegistration<TestConvertToSPIRVPass>(); |
169 |
| -} |
170 |
| -std::unique_ptr<Pass> createTestConvertToSPIRVPass(bool convertGPUModules, |
171 |
| - bool nestInGPUModule) { |
172 |
| - return std::make_unique<TestConvertToSPIRVPass>(convertGPUModules, |
173 |
| - nestInGPUModule); |
174 |
| -} |
175 |
| -} // namespace mlir::test |
0 commit comments