Skip to content

Commit b165650

Browse files
[mlir][sparse] Return actual identity map instead of null map (#70365)
Changes: 1. For both dimToLvl and lvlToDim, always returns the actual map instead of AffineMap() for identity map. 2. Updated custom builder for encoding to have default values. 3. Non-inferable lvlToDim will still return AffineMap() during inference, so it will be caught by verifier.
1 parent f74f213 commit b165650

File tree

4 files changed

+14
-10
lines changed

4 files changed

+14
-10
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,10 +303,13 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
303303

304304
let builders = [
305305
AttrBuilder<(ins "ArrayRef<::mlir::sparse_tensor::DimLevelType>":$lvlTypes,
306-
"AffineMap":$dimToLvl,
307-
"AffineMap":$lvlToDim,
308-
"unsigned":$posWidth,
309-
"unsigned":$crdWidth), [{
306+
CArg<"AffineMap", "{}">:$dimToLvl,
307+
CArg<"AffineMap", "{}">:$lvlToDim,
308+
CArg<"unsigned", "0">:$posWidth,
309+
CArg<"unsigned", "0">:$crdWidth), [{
310+
if (!dimToLvl) {
311+
dimToLvl = ::mlir::AffineMap::getMultiDimIdentityMap(lvlTypes.size(), $_ctxt);
312+
}
310313
if (!lvlToDim) {
311314
lvlToDim = ::mlir::sparse_tensor::inferLvlToDim(dimToLvl, $_ctxt);
312315
}

mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,7 @@ AffineMap DimLvlMap::getDimToLvlMap(MLIRContext *context) const {
313313
lvlAffines.reserve(getLvlRank());
314314
for (const auto &lvlSpec : lvlSpecs)
315315
lvlAffines.push_back(lvlSpec.getExpr().getAffineExpr());
316-
auto map = AffineMap::get(getDimRank(), getSymRank(), lvlAffines, context);
317-
if (map.isIdentity()) return AffineMap();
316+
auto map = AffineMap::get(getDimRank(), getSymRank(), lvlAffines, context);
318317
return map;
319318
}
320319

@@ -328,7 +327,9 @@ AffineMap DimLvlMap::getLvlToDimMap(MLIRContext *context) const {
328327
}
329328
}
330329
auto map = AffineMap::get(getLvlRank(), getSymRank(), dimAffines, context);
331-
if (dimAffines.empty() || map.isIdentity())
330+
// If no lvlToDim map was passed in, returns a null AffineMap and infers it
331+
// in SparseTensorEncodingAttr::parse.
332+
if (dimAffines.empty())
332333
return AffineMap();
333334
return map;
334335
}

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ SparseTensorEncodingAttr
291291
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
292292
assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
293293
return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
294-
getLvlToDim(), getPosWidth(),
294+
AffineMap(), getPosWidth(),
295295
getCrdWidth());
296296
}
297297

mlir/test/python/dialects/sparse_tensor/dialect.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ def testEncodingAttr1D():
3030

3131
# CHECK: lvl_types: [<DimLevelType.compressed: 8>]
3232
print(f"lvl_types: {casted.lvl_types}")
33-
# CHECK: dim_to_lvl: None
33+
# CHECK: dim_to_lvl: (d0) -> (d0)
3434
print(f"dim_to_lvl: {casted.dim_to_lvl}")
35-
# CHECK: lvl_to_dim: None
35+
# CHECK: lvl_to_dim: (d0) -> (d0)
3636
print(f"lvl_to_dim: {casted.lvl_to_dim}")
3737
# CHECK: pos_width: 16
3838
print(f"pos_width: {casted.pos_width}")

0 commit comments

Comments
 (0)