22
22
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
23
23
24
24
namespace mlir {
25
+ #define GEN_PASS_DEF_SPARSEREINTERPRETMAP
25
26
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
26
27
#define GEN_PASS_DEF_SPARSIFICATIONPASS
27
28
#define GEN_PASS_DEF_POSTSPARSIFICATIONREWRITE
@@ -44,9 +45,21 @@ namespace {
44
45
// Passes implementation.
45
46
// ===----------------------------------------------------------------------===//
46
47
48
+ struct SparseReinterpretMap
49
+ : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
50
+ SparseReinterpretMap () = default ;
51
+ SparseReinterpretMap (const SparseReinterpretMap &pass) = default ;
52
+
53
+ void runOnOperation () override {
54
+ auto *ctx = &getContext ();
55
+ RewritePatternSet patterns (ctx);
56
+ populateSparseReinterpretMap (patterns);
57
+ (void )applyPatternsAndFoldGreedily (getOperation (), std::move (patterns));
58
+ }
59
+ };
60
+
47
61
struct PreSparsificationRewritePass
48
62
: public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
49
-
50
63
PreSparsificationRewritePass () = default ;
51
64
PreSparsificationRewritePass (const PreSparsificationRewritePass &pass) =
52
65
default ;
@@ -61,7 +74,6 @@ struct PreSparsificationRewritePass
61
74
62
75
struct SparsificationPass
63
76
: public impl::SparsificationPassBase<SparsificationPass> {
64
-
65
77
SparsificationPass () = default ;
66
78
SparsificationPass (const SparsificationPass &pass) = default ;
67
79
SparsificationPass (const SparsificationOptions &options) {
@@ -108,7 +120,6 @@ struct StageSparseOperationsPass
108
120
struct PostSparsificationRewritePass
109
121
: public impl::PostSparsificationRewriteBase<
110
122
PostSparsificationRewritePass> {
111
-
112
123
PostSparsificationRewritePass () = default ;
113
124
PostSparsificationRewritePass (const PostSparsificationRewritePass &pass) =
114
125
default ;
@@ -129,7 +140,6 @@ struct PostSparsificationRewritePass
129
140
130
141
struct SparseTensorConversionPass
131
142
: public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
132
-
133
143
SparseTensorConversionPass () = default ;
134
144
SparseTensorConversionPass (const SparseTensorConversionPass &pass) = default ;
135
145
@@ -200,7 +210,6 @@ struct SparseTensorConversionPass
200
210
201
211
struct SparseTensorCodegenPass
202
212
: public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
203
-
204
213
SparseTensorCodegenPass () = default ;
205
214
SparseTensorCodegenPass (const SparseTensorCodegenPass &pass) = default ;
206
215
SparseTensorCodegenPass (bool createDeallocs, bool enableInit) {
@@ -266,7 +275,6 @@ struct SparseTensorCodegenPass
266
275
267
276
struct SparseBufferRewritePass
268
277
: public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
269
-
270
278
SparseBufferRewritePass () = default ;
271
279
SparseBufferRewritePass (const SparseBufferRewritePass &pass) = default ;
272
280
SparseBufferRewritePass (bool enableInit) {
@@ -283,7 +291,6 @@ struct SparseBufferRewritePass
283
291
284
292
struct SparseVectorizationPass
285
293
: public impl::SparseVectorizationBase<SparseVectorizationPass> {
286
-
287
294
SparseVectorizationPass () = default ;
288
295
SparseVectorizationPass (const SparseVectorizationPass &pass) = default ;
289
296
SparseVectorizationPass (unsigned vl, bool vla, bool sidx32) {
@@ -306,7 +313,6 @@ struct SparseVectorizationPass
306
313
307
314
struct SparseGPUCodegenPass
308
315
: public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
309
-
310
316
SparseGPUCodegenPass () = default ;
311
317
SparseGPUCodegenPass (const SparseGPUCodegenPass &pass) = default ;
312
318
SparseGPUCodegenPass (unsigned nT) { numThreads = nT; }
@@ -321,7 +327,6 @@ struct SparseGPUCodegenPass
321
327
322
328
struct StorageSpecifierToLLVMPass
323
329
: public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
324
-
325
330
StorageSpecifierToLLVMPass () = default ;
326
331
327
332
void runOnOperation () override {
@@ -363,6 +368,10 @@ struct StorageSpecifierToLLVMPass
363
368
// Pass creation methods.
364
369
// ===----------------------------------------------------------------------===//
365
370
371
+ std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass () {
372
+ return std::make_unique<SparseReinterpretMap>();
373
+ }
374
+
366
375
std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass () {
367
376
return std::make_unique<PreSparsificationRewritePass>();
368
377
}
0 commit comments