Skip to content

Commit 7cfac1b

Browse files
authored
[mlir][sparse] add boilterplate code for a new reintepret map pass (#70393)
The interesting stuff is of course still coming ;-)
1 parent 0359a78 commit 7cfac1b

File tree

6 files changed

+77
-9
lines changed

6 files changed

+77
-9
lines changed

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,14 @@ enum class GPUDataTransferStrategy { kRegularDMA, kZeroCopy, kPinnedDMA };
4747
#define GEN_PASS_DECL
4848
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
4949

50+
//===----------------------------------------------------------------------===//
51+
// The SparseReinterpretMap pass.
52+
//===----------------------------------------------------------------------===//
53+
54+
void populateSparseReinterpretMap(RewritePatternSet &patterns);
55+
56+
std::unique_ptr<Pass> createSparseReinterpretMapPass();
57+
5058
//===----------------------------------------------------------------------===//
5159
// The PreSparsificationRewriting pass.
5260
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,24 @@
1111

1212
include "mlir/Pass/PassBase.td"
1313

14+
def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
15+
let summary = "Reinterprets sparse tensor type mappings";
16+
let description = [{
17+
A pass that reinterprets the mappings in all sparse tensor types in a way that
18+
enables subsequent sparification. This involves expressing all `linalg.generic`
19+
operations in terms of level coordinates (rather than the dimension coordinates
20+
of the input tensors) to align the iteration space with the potentially remapped
21+
level space as well as resolving cycles in the resulting iteration graphs with
22+
explicit sparse tensor conversions where needed.
23+
}];
24+
let constructor = "mlir::createSparseReinterpretMapPass()";
25+
let dependentDialects = [
26+
"affine::AffineDialect",
27+
"linalg::LinalgDialect",
28+
"sparse_tensor::SparseTensorDialect",
29+
];
30+
}
31+
1432
def PreSparsificationRewrite : Pass<"pre-sparsification-rewrite", "ModuleOp"> {
1533
let summary = "Applies sparse tensor rewriting rules prior to sparsification";
1634
let description = [{

mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
55
LoopEmitter.cpp
66
SparseBufferRewriting.cpp
77
SparseGPUCodegen.cpp
8+
SparseReinterpretMap.cpp
89
SparseStorageSpecifierToLLVM.cpp
910
SparseTensorCodegen.cpp
1011
SparseTensorConversion.cpp
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
10+
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
11+
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
12+
13+
namespace {
14+
15+
// TODO:
16+
// (1) insert the zero-cost sparse_tensor.reinterpret_map ops
17+
// (2) rewrite linalg.generic ops traits on level crds
18+
// (3) compute topsort, and resolve cyles with sparse_tensor.convert ops
19+
20+
} // namespace
21+
22+
void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns) {}

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2323

2424
namespace mlir {
25+
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
2526
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
2627
#define GEN_PASS_DEF_SPARSIFICATIONPASS
2728
#define GEN_PASS_DEF_POSTSPARSIFICATIONREWRITE
@@ -44,9 +45,21 @@ namespace {
4445
// Passes implementation.
4546
//===----------------------------------------------------------------------===//
4647

48+
struct SparseReinterpretMap
49+
: public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
50+
SparseReinterpretMap() = default;
51+
SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
52+
53+
void runOnOperation() override {
54+
auto *ctx = &getContext();
55+
RewritePatternSet patterns(ctx);
56+
populateSparseReinterpretMap(patterns);
57+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
58+
}
59+
};
60+
4761
struct PreSparsificationRewritePass
4862
: public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
49-
5063
PreSparsificationRewritePass() = default;
5164
PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
5265
default;
@@ -61,7 +74,6 @@ struct PreSparsificationRewritePass
6174

6275
struct SparsificationPass
6376
: public impl::SparsificationPassBase<SparsificationPass> {
64-
6577
SparsificationPass() = default;
6678
SparsificationPass(const SparsificationPass &pass) = default;
6779
SparsificationPass(const SparsificationOptions &options) {
@@ -108,7 +120,6 @@ struct StageSparseOperationsPass
108120
struct PostSparsificationRewritePass
109121
: public impl::PostSparsificationRewriteBase<
110122
PostSparsificationRewritePass> {
111-
112123
PostSparsificationRewritePass() = default;
113124
PostSparsificationRewritePass(const PostSparsificationRewritePass &pass) =
114125
default;
@@ -129,7 +140,6 @@ struct PostSparsificationRewritePass
129140

130141
struct SparseTensorConversionPass
131142
: public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
132-
133143
SparseTensorConversionPass() = default;
134144
SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
135145

@@ -200,7 +210,6 @@ struct SparseTensorConversionPass
200210

201211
struct SparseTensorCodegenPass
202212
: public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
203-
204213
SparseTensorCodegenPass() = default;
205214
SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
206215
SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
@@ -266,7 +275,6 @@ struct SparseTensorCodegenPass
266275

267276
struct SparseBufferRewritePass
268277
: public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
269-
270278
SparseBufferRewritePass() = default;
271279
SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
272280
SparseBufferRewritePass(bool enableInit) {
@@ -283,7 +291,6 @@ struct SparseBufferRewritePass
283291

284292
struct SparseVectorizationPass
285293
: public impl::SparseVectorizationBase<SparseVectorizationPass> {
286-
287294
SparseVectorizationPass() = default;
288295
SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
289296
SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
@@ -306,7 +313,6 @@ struct SparseVectorizationPass
306313

307314
struct SparseGPUCodegenPass
308315
: public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
309-
310316
SparseGPUCodegenPass() = default;
311317
SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
312318
SparseGPUCodegenPass(unsigned nT) { numThreads = nT; }
@@ -321,7 +327,6 @@ struct SparseGPUCodegenPass
321327

322328
struct StorageSpecifierToLLVMPass
323329
: public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
324-
325330
StorageSpecifierToLLVMPass() = default;
326331

327332
void runOnOperation() override {
@@ -363,6 +368,10 @@ struct StorageSpecifierToLLVMPass
363368
// Pass creation methods.
364369
//===----------------------------------------------------------------------===//
365370

371+
std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
372+
return std::make_unique<SparseReinterpretMap>();
373+
}
374+
366375
std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
367376
return std::make_unique<PreSparsificationRewritePass>();
368377
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: mlir-opt %s --sparse-reinterpret-map | FileCheck %s
2+
3+
#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
4+
5+
// CHECK-LABEL: func @sparse_nop(
6+
// CHECK-SAME: %[[A0:.*]]: tensor<?xf64, #sparse_tensor.encoding<{{{.*}}}>>)
7+
// CHECK: return %[[A0]]
8+
func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
9+
return %arg0 : tensor<?xf64, #SparseVector>
10+
}

0 commit comments

Comments
 (0)