
Commit 0e34dbb

[mlir][sparse] fix bug with all-dense assembler (#108615)
When a function prototype contains only all-dense "sparse" tensors, the assembler would skip the method conversion because it tested purely on input/output type counts, and an all-dense tensor externalizes to a single values array, which leaves those counts unchanged. The rewrite should instead trigger on the presence of any sparse-tensor annotation.
1 parent: 459a82e

7 files changed: +124 -11 lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp

Lines changed: 9 additions & 5 deletions
@@ -24,14 +24,16 @@ using namespace sparse_tensor;
 //===----------------------------------------------------------------------===//
 
 // Convert type range to new types range, with sparse tensors externalized.
-static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
+static void convTypes(bool &hasAnnotation, TypeRange types,
+                      SmallVectorImpl<Type> &convTypes,
                       SmallVectorImpl<Type> *extraTypes, bool directOut) {
   for (auto type : types) {
     // All "dense" data passes through unmodified.
     if (!getSparseTensorEncoding(type)) {
       convTypes.push_back(type);
       continue;
     }
+    hasAnnotation = true;
 
     // Convert the external representations of the pos/crd/val arrays.
     const SparseTensorType stt(cast<RankedTensorType>(type));
@@ -176,12 +178,14 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
     SmallVector<Type> inputTypes;
     SmallVector<Type> outputTypes;
     SmallVector<Type> extraTypes;
-    convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, false);
-    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);
+    bool hasAnnotation = false;
+    convTypes(hasAnnotation, funcOp.getArgumentTypes(), inputTypes, nullptr,
+              false);
+    convTypes(hasAnnotation, funcOp.getResultTypes(), outputTypes, &extraTypes,
+              directOut);
 
     // Only sparse inputs or outputs need a wrapper method.
-    if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
-        outputTypes.size() == funcOp.getResultTypes().size())
+    if (!hasAnnotation)
       return failure();
 
     // Modify the original method into an internal, private method.
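Why the old guard misfires: an all-dense encoded tensor externalizes to exactly one buffer, its values array, since dense levels contribute no pos/crd arrays. The converted type lists therefore have the same length as the originals, and a count-based test concludes no wrapper is needed. A minimal Python sketch of the old and new checks (hypothetical stand-in names, not the actual MLIR API):

def needs_wrapper_old(arg_types, conv_args, res_types, conv_results):
    # Old check: rewrite only when externalization changed a type count.
    # An all-dense tensor converts 1:1 to its values buffer, so the
    # counts match and the rewrite is wrongly skipped.
    return len(conv_args) != len(arg_types) or len(conv_results) != len(res_types)


def needs_wrapper_new(has_annotation):
    # New check: rewrite whenever any sparse encoding was observed.
    return has_annotation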

mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py

Lines changed: 4 additions & 1 deletion
@@ -163,7 +163,10 @@ def main():
         )
         opt = f"parallelization-strategy=none"
         compiler = sparsifier.Sparsifier(
-            options=opt, opt_level=0, shared_libs=[support_lib]
+            extras="",
+            options=opt,
+            opt_level=0,
+            shared_libs=[support_lib],
         )
         build_compile_and_run_SDDMMM(attr, compiler)
         count = count + 1

mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py

Lines changed: 1 addition & 1 deletion
@@ -141,7 +141,7 @@ def main():
         ]
         bitwidths = [0]
         compiler = sparsifier.Sparsifier(
-            options=opt, opt_level=0, shared_libs=[support_lib]
+            extras="", options=opt, opt_level=0, shared_libs=[support_lib]
         )
         for level in levels:
             for ordering in orderings:
mlir/test/Integration/Dialect/SparseTensor/python/ (new file)

Lines changed: 94 additions & 0 deletions

@@ -0,0 +1,95 @@
+# RUN: env SUPPORT_LIB=%mlir_c_runner_utils \
+# RUN:   %PYTHON %s | FileCheck %s
+
+import ctypes
+import errno  # needed for the FileNotFoundError check below
+import os
+import sys
+import tempfile
+
+from mlir import ir
+from mlir import runtime as rt
+from mlir.dialects import builtin
+from mlir.dialects import sparse_tensor as st
+import numpy as np
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+from tools import sparsifier
+
+
+def boilerplate():
+    """Returns boilerplate main method."""
+    return """
+#Dense = #sparse_tensor.encoding<{
+  map = (i, j) -> (i: dense, j: dense)
+}>
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @add(%st_0 : tensor<3x4xf64, #Dense>,
+               %st_1 : tensor<3x4xf64, #Dense>) attributes { llvm.emit_c_interface } {
+  %out_st = tensor.empty() : tensor<3x4xf64, #Dense>
+  %res = linalg.generic {indexing_maps = [#map, #map, #map],
+                         iterator_types = ["parallel", "parallel"]}
+    ins(%st_0, %st_1 : tensor<3x4xf64, #Dense>, tensor<3x4xf64, #Dense>)
+    outs(%out_st : tensor<3x4xf64, #Dense>) {
+    ^bb0(%in_0: f64, %in_1: f64, %out: f64):
+      %2 = sparse_tensor.binary %in_0, %in_1 : f64, f64 to f64
+        overlap = {
+          ^bb0(%arg1: f64, %arg2: f64):
+            %3 = arith.addf %arg1, %arg2 : f64
+            sparse_tensor.yield %3 : f64
+        }
+        left = {
+          ^bb0(%arg1: f64):
+            sparse_tensor.yield %arg1 : f64
+        }
+        right = {
+          ^bb0(%arg1: f64):
+            sparse_tensor.yield %arg1 : f64
+        }
+      linalg.yield %2 : f64
+  } -> tensor<3x4xf64, #Dense>
+  sparse_tensor.print %res : tensor<3x4xf64, #Dense>
+  return
+}
+"""
+
+
+def main():
+    support_lib = os.getenv("SUPPORT_LIB")
+    assert support_lib is not None, "SUPPORT_LIB is undefined"
+    if not os.path.exists(support_lib):
+        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+
+    # CHECK-LABEL: TEST: all dense
+    # CHECK: ---- Sparse Tensor ----
+    # CHECK: nse = 12
+    # CHECK: dim = ( 3, 4 )
+    # CHECK: lvl = ( 3, 4 )
+    # CHECK: values : ( 1, 1, 0, 1, 0, 6, 2, 3, 0, 0, 0, 2 )
+    # CHECK: ----
+    print("\nTEST: all dense")
+    with ir.Context() as ctx, ir.Location.unknown():
+        compiler = sparsifier.Sparsifier(
+            extras="sparse-assembler,",
+            options="enable-runtime-library=false",
+            opt_level=2,
+            shared_libs=[support_lib],
+        )
+        module = ir.Module.parse(boilerplate())
+        engine = compiler.compile_and_jit(module)
+        print(module)
+
+        a = np.array([1, 0, 0, 1, 0, 2, 2, 0, 0, 0, 0, 1], dtype=np.float64)
+        b = np.array([0, 1, 0, 0, 0, 4, 0, 3, 0, 0, 0, 1], dtype=np.float64)
+        mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+        mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+
+        # Invoke the kernel and get numpy output.
+        # Built-in bufferization uses in-out buffers.
+        engine.invoke("add", mem_a, mem_b)
+
+
+if __name__ == "__main__":
+    main()
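The point of prepending sparse-assembler is visible in how the kernel is invoked: each all-dense tensor<3x4xf64, #Dense> argument crosses the boundary as one flat values buffer, which is why plain 12-element numpy arrays suffice above. A small illustrative sketch of that correspondence (the variable names are mine, not part of the test):

import numpy as np

# An all-dense 3x4 tensor externalizes to just its values: 12 contiguous
# f64 entries in row-major (i, j) order, with no pos/crd buffers.
dense_3x4 = np.array(
    [[1, 0, 0, 1],
     [0, 2, 2, 0],
     [0, 0, 0, 1]],
    dtype=np.float64,
)
values = dense_3x4.ravel()  # shape (12,): the buffer the wrapped @add consumes
assert values.tolist() == [1, 0, 0, 1, 0, 2, 2, 0, 0, 0, 0, 1]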

mlir/test/Integration/Dialect/SparseTensor/python/test_output.py

Lines changed: 1 addition & 1 deletion
@@ -139,7 +139,7 @@ def main():
         ]
         bitwidths = [8, 64]
         compiler = sparsifier.Sparsifier(
-            options="", opt_level=2, shared_libs=[support_lib]
+            extras="", options="", opt_level=2, shared_libs=[support_lib]
         )
         for level in levels:
             for ordering, id_map in orderings:

mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py

Lines changed: 4 additions & 1 deletion
@@ -195,7 +195,10 @@ def main():
     with ir.Context() as ctx, ir.Location.unknown():
         sparsification_options = f"parallelization-strategy=none "
         compiler = sparsifier.Sparsifier(
-            options=sparsification_options, opt_level=0, shared_libs=[support_lib]
+            extras="",
+            options=sparsification_options,
+            opt_level=0,
+            shared_libs=[support_lib],
         )
         f64 = ir.F64Type.get()
         # Be careful about increasing this because

mlir/test/Integration/Dialect/SparseTensor/python/tools/sparsifier.py

Lines changed: 11 additions & 2 deletions
@@ -13,8 +13,17 @@
 class Sparsifier:
     """Sparsifier class for compiling and building MLIR modules."""
 
-    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
-        pipeline = f"builtin.module(sparsifier{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})"
+    def __init__(
+        self,
+        extras: str,
+        options: str,
+        opt_level: int,
+        shared_libs: Sequence[str],
+    ):
+        pipeline = (
+            f"builtin.module({extras}sparsifier{{{options} reassociate-fp-reductions=1"
+            " enable-index-optimizations=1})"
+        )
         self.pipeline = pipeline
         self.opt_level = opt_level
         self.shared_libs = shared_libs
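The extras string is spliced verbatim in front of sparsifier, so callers pass either an empty string or a comma-terminated pass list. A usage sketch under the same assumptions as the tests above (sparsifier imported from tools, support_lib holding the path to the runner-utils library):

# Default pipeline, no extra passes:
#   builtin.module(sparsifier{ reassociate-fp-reductions=1 enable-index-optimizations=1})
plain = sparsifier.Sparsifier(
    extras="", options="", opt_level=2, shared_libs=[support_lib]
)

# Prepend the sparse-assembler pass; the trailing comma is required
# because extras is interpolated verbatim into the pipeline string:
#   builtin.module(sparse-assembler,sparsifier{enable-runtime-library=false ...})
wrapped = sparsifier.Sparsifier(
    extras="sparse-assembler,",
    options="enable-runtime-library=false",
    opt_level=2,
    shared_libs=[support_lib],
)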
