Skip to content

Commit d4088e7

Browse files
[mlir][sparse] Populate lvlToDim (#68937)
Updates: 1. Infer lvlToDim from dimToLvl. 2. Add more tests for block sparsity. 3. Finish TODOs related to lvlToDim, including adding lvlToDim to the Python bindings. Verification of the lvlToDim map that the user provides will be implemented in the next PR.
1 parent 9922aad commit d4088e7

File tree

13 files changed

+177
-22
lines changed

13 files changed

+177
-22
lines changed

mlir/include/mlir-c/Dialect/SparseTensor.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,10 @@ MLIR_CAPI_EXPORTED bool
5151
mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);
5252

5353
/// Creates a `sparse_tensor.encoding` attribute with the given parameters.
54-
/// TODO: add a version that supplied lvlToDim when it cannot be inferred
5554
MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
5655
MlirContext ctx, intptr_t lvlRank,
5756
enum MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimToLvl,
58-
int posWidth, int crdWidth);
57+
MlirAffineMap lvlToDim, int posWidth, int crdWidth);
5958

6059
/// Returns the level-rank of the `sparse_tensor.encoding` attribute.
6160
MLIR_CAPI_EXPORTED intptr_t

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,19 @@ inline bool hasAnySparseOperandOrResult(Operation *op) {
160160
return hasAnySparseOperand(op) || hasAnySparseResult(op);
161161
}
162162

163+
//
164+
// Inference.
165+
//
166+
167+
/// Given the dimToLvl map, infers the lvlToDim map, or returns
168+
/// empty Affine map when inference fails.
169+
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context);
170+
171+
/// Returns the lvlToDim map for the given dimToLvl map specific
172+
/// to the block sparse cases.
173+
/// Asserts on failure (so only use when known to succeed).
174+
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context);
175+
163176
//
164177
// Reordering.
165178
//

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
307307
"AffineMap":$lvlToDim,
308308
"unsigned":$posWidth,
309309
"unsigned":$crdWidth), [{
310+
if (!lvlToDim) {
311+
lvlToDim = ::mlir::sparse_tensor::inferLvlToDim(dimToLvl, $_ctxt);
312+
}
310313
return $_get($_ctxt, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
311314
ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr>{});
312315
}]>

mlir/lib/Bindings/Python/DialectSparseTensor.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,17 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
4141
.def_classmethod(
4242
"get",
4343
[](py::object cls, std::vector<MlirSparseTensorDimLevelType> lvlTypes,
44-
std::optional<MlirAffineMap> dimToLvl, int posWidth, int crdWidth,
44+
std::optional<MlirAffineMap> dimToLvl,
45+
std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
4546
MlirContext context) {
46-
// TODO: provide dimToLvl
4747
return cls(mlirSparseTensorEncodingAttrGet(
4848
context, lvlTypes.size(), lvlTypes.data(),
49-
dimToLvl ? *dimToLvl : MlirAffineMap{nullptr}, posWidth,
49+
dimToLvl ? *dimToLvl : MlirAffineMap{nullptr},
50+
lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth,
5051
crdWidth));
5152
},
5253
py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"),
53-
py::arg("pos_width"), py::arg("crd_width"),
54+
py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"),
5455
py::arg("context") = py::none(),
5556
"Gets a sparse_tensor.encoding from parameters.")
5657
.def_property_readonly(
@@ -71,6 +72,14 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
7172
return {};
7273
return ret;
7374
})
75+
.def_property_readonly(
76+
"lvl_to_dim",
77+
[](MlirAttribute self) -> std::optional<MlirAffineMap> {
78+
MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self);
79+
if (mlirAffineMapIsNull(ret))
80+
return {};
81+
return ret;
82+
})
7483
.def_property_readonly("pos_width",
7584
mlirSparseTensorEncodingAttrGetPosWidth)
7685
.def_property_readonly("crd_width",

mlir/lib/CAPI/Dialect/SparseTensor.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,14 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
4848
MlirAttribute
4949
mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
5050
MlirSparseTensorDimLevelType const *lvlTypes,
51-
MlirAffineMap dimToLvl, int posWidth,
52-
int crdWidth) {
51+
MlirAffineMap dimToLvl, MlirAffineMap lvlToDim,
52+
int posWidth, int crdWidth) {
5353
SmallVector<DimLevelType> cppLvlTypes;
5454
cppLvlTypes.reserve(lvlRank);
5555
for (intptr_t l = 0; l < lvlRank; ++l)
5656
cppLvlTypes.push_back(static_cast<DimLevelType>(lvlTypes[l]));
57-
mlir::AffineMap lvlToDim; // TODO: provide in API
5857
return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes,
59-
unwrap(dimToLvl), lvlToDim,
58+
unwrap(dimToLvl), unwrap(lvlToDim),
6059
posWidth, crdWidth));
6160
}
6261

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,8 @@ Type SparseTensorEncodingAttr::getCrdType() const {
293293
SparseTensorEncodingAttr
294294
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
295295
assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
296-
// TODO: infer lvlToDim
297296
return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
298-
/*lvlToDim*/ AffineMap(), getPosWidth(),
297+
getLvlToDim(), getPosWidth(),
299298
getCrdWidth());
300299
}
301300

@@ -583,7 +582,8 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
583582
#undef RETURN_ON_FAIL
584583

585584
// Construct struct-like storage for attribute.
586-
AffineMap lvlToDim; // TODO: infer
585+
// TODO: Fetch lvlToDim if user provides one
586+
AffineMap lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
587587
return parser.getChecked<SparseTensorEncodingAttr>(
588588
parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
589589
dimSlices);
@@ -749,6 +749,75 @@ mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
749749
return nullptr;
750750
}
751751

752+
AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
753+
MLIRContext *context) {
754+
auto map = static_cast<AffineMap>(dimToLvl);
755+
AffineMap lvlToDim;
756+
// Return an empty lvlToDim when inference is not successful.
757+
if (!map || map.getNumSymbols() != 0) {
758+
lvlToDim = AffineMap();
759+
} else if (map.isPermutation()) {
760+
lvlToDim = inversePermutation(map);
761+
} else {
762+
// TODO: check if it's block sparsity
763+
lvlToDim = inverseBlockSparsity(map, context);
764+
}
765+
return lvlToDim;
766+
}
767+
768+
AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
769+
MLIRContext *context) {
770+
SmallVector<AffineExpr> lvlExprs;
771+
auto numLvls = dimToLvl.getNumResults();
772+
lvlExprs.reserve(numLvls);
773+
// lvlExprComponents stores information of the floordiv and mod operations
774+
// applied to the same dimension, so as to build the lvlToDim map.
775+
std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
776+
for (unsigned i = 0, n = numLvls; i < n; i++) {
777+
auto result = dimToLvl.getResult(i);
778+
if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
779+
if (result.getKind() == AffineExprKind::FloorDiv) {
780+
// Position of the dimension in dimToLvl.
781+
auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
782+
assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
783+
"expected only one floordiv for each dimension");
784+
SmallVector<AffineExpr, 3> components;
785+
// Level variable for floordiv.
786+
components.push_back(getAffineDimExpr(i, context));
787+
// Multiplier.
788+
components.push_back(binOp.getRHS());
789+
// Map key is the position of the dimension.
790+
lvlExprComponents[pos] = components;
791+
} else if (result.getKind() == AffineExprKind::Mod) {
792+
auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
793+
assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
794+
"expected floordiv before mod");
795+
// Add level variable for mod to the same vector
796+
// of the corresponding floordiv.
797+
lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
798+
} else {
799+
assert(false && "expected floordiv or mod");
800+
}
801+
} else {
802+
lvlExprs.push_back(getAffineDimExpr(i, context));
803+
}
804+
}
805+
// Build lvlExprs from lvlExprComponents.
806+
// For example, for il = i floordiv 2 and ii = i mod 2, the components
807+
// would be [il, 2, ii]. It could be used to build the AffineExpr
808+
// i = il * 2 + ii in lvlToDim.
809+
for (auto &components : lvlExprComponents) {
810+
assert(components.second.size() == 3 &&
811+
"expected 3 components to build lvlExprs");
812+
auto mulOp = getAffineBinaryOpExpr(
813+
AffineExprKind::Mul, components.second[0], components.second[1]);
814+
auto addOp =
815+
getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
816+
lvlExprs.push_back(addOp);
817+
}
818+
return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
819+
}
820+
752821
bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
753822
Level startLvl, bool isUnique) {
754823
if (!enc ||
@@ -811,7 +880,7 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
811880
// default value.
812881
unsigned posWidth = src.getPosWidth();
813882
unsigned crdWidth = src.getCrdWidth();
814-
AffineMap invPerm; // TODO
883+
AffineMap invPerm = src.getLvlToDim();
815884
auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
816885
invPerm, posWidth, crdWidth);
817886
return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);

mlir/test/CAPI/sparse_tensor.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
4040
// CHECK: level_type: 4
4141
// CHECK: level_type: 8
4242
// CHECK: level_type: 8
43+
MlirAffineMap lvlToDim =
44+
mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
4345
int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
4446
enum MlirSparseTensorDimLevelType *lvlTypes =
4547
malloc(sizeof(enum MlirSparseTensorDimLevelType) * lvlRank);
@@ -53,9 +55,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
5355
// CHECK: crdWidth: 64
5456
int crdWidth = mlirSparseTensorEncodingAttrGetCrdWidth(originalAttr);
5557
fprintf(stderr, "crdWidth: %d\n", crdWidth);
56-
// TODO: lvlToDim
5758
MlirAttribute newAttr = mlirSparseTensorEncodingAttrGet(
58-
ctx, lvlRank, lvlTypes, dimToLvl, posWidth, crdWidth);
59+
ctx, lvlRank, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth);
5960
mlirAttributeDump(newAttr); // For debugging filecheck output.
6061
// CHECK: equal: 1
6162
fprintf(stderr, "equal: %d\n", mlirAttributeEqual(originalAttr, newAttr));

mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,24 @@ func.func private @BSR(%arg0: tensor<?x?xf64, #BSR>) {
160160

161161
// -----
162162

163+
#BCSR = #sparse_tensor.encoding<{
164+
map = ( i, j, k ) ->
165+
( i floordiv 2 : dense,
166+
j floordiv 3 : dense,
167+
k floordiv 4 : compressed,
168+
i mod 2 : dense,
169+
j mod 3 : dense,
170+
k mod 4 : dense
171+
)
172+
}>
173+
174+
// CHECK-LABEL: func private @BCSR(
175+
// CHECK-SAME: tensor<?x?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 floordiv 2 : dense, d1 floordiv 3 : dense, d2 floordiv 4 : compressed, d0 mod 2 : dense, d1 mod 3 : dense, d2 mod 4 : dense) }>>
176+
func.func private @BCSR(%arg0: tensor<?x?x?xf64, #BCSR>) {
177+
return
178+
}
179+
// -----
180+
163181
#BSR_explicit = #sparse_tensor.encoding<{
164182
map =
165183
{il, jl, ii, jj}
@@ -194,3 +212,37 @@ func.func private @BSR_explicit(%arg0: tensor<?x?xf64, #BSR_explicit>) {
194212
func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
195213
return
196214
}
215+
216+
// -----
217+
218+
#NV_24 = #sparse_tensor.encoding<{
219+
map = ( i, j, k ) ->
220+
( i : dense,
221+
j : dense,
222+
k floordiv 4 : dense,
223+
k mod 4 : block2_4
224+
)
225+
}>
226+
227+
// CHECK-LABEL: func private @NV_24(
228+
// CHECK-SAME: tensor<?x?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : block2_4) }>>
229+
func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
230+
return
231+
}
232+
233+
// -----
234+
235+
#NV_24 = #sparse_tensor.encoding<{
236+
map = ( i, j, k ) ->
237+
( i : dense,
238+
k floordiv 4 : dense,
239+
j : dense,
240+
k mod 4 : block2_4
241+
)
242+
}>
243+
244+
// CHECK-LABEL: func private @NV_24(
245+
// CHECK-SAME: tensor<?x?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : block2_4) }>>
246+
func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
247+
return
248+
}

mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def main():
155155
for iwidth in [32]:
156156
for e in [True]:
157157
attr = st.EncodingAttr.get(
158-
level, ordering, pwidth, iwidth
158+
level, ordering, None, pwidth, iwidth
159159
)
160160
opt = f"parallelization-strategy=none"
161161
compiler = sparse_compiler.SparseCompiler(

mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def main():
145145
for pwidth in bitwidths:
146146
for iwidth in bitwidths:
147147
attr = st.EncodingAttr.get(
148-
level, ordering, pwidth, iwidth
148+
level, ordering, None, pwidth, iwidth
149149
)
150150
build_compile_and_run_SpMM(attr, compiler)
151151
count = count + 1

0 commit comments

Comments (0)