Skip to content

Commit ef22298

Browse files
authored
[mlir][sparse] implements sparse_tensor.reinterpret_map (#70388)
1 parent 56b99f0 commit ef22298

File tree

4 files changed

+51
-4
lines changed

4 files changed

+51
-4
lines changed

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -453,11 +453,22 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
453453
// Do constant propagation on the affine map.
454454
AffineExpr evalExp =
455455
simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
456-
if (auto c = evalExp.dyn_cast<AffineConstantExpr>())
456+
if (auto c = evalExp.dyn_cast<AffineConstantExpr>()) {
457457
ret.push_back(c.getValue() + 1);
458-
else
458+
} else {
459+
if (auto mod = evalExp.dyn_cast<AffineBinaryOpExpr>();
460+
mod && mod.getKind() == AffineExprKind::Mod) {
461+
// We can still infer a static bound for expressions in form
462+
// "d % constant" since d % constant \in [0, constant).
463+
if (auto bound = mod.getRHS().dyn_cast<AffineConstantExpr>()) {
464+
ret.push_back(bound.getValue());
465+
continue;
466+
}
467+
}
459468
ret.push_back(ShapedType::kDynamic);
469+
}
460470
}
471+
assert(ret.size() == rank);
461472
return ret;
462473
}
463474

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,18 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
725725
}
726726
};
727727

728+
class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
729+
public:
730+
using OpConversionPattern::OpConversionPattern;
731+
LogicalResult
732+
matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
733+
ConversionPatternRewriter &rewriter) const override {
734+
// Simply fold the operation.
735+
rewriter.replaceOp(op, adaptor.getSource());
736+
return success();
737+
}
738+
};
739+
728740
/// Sparse codegen rule for the alloc operator.
729741
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
730742
class SparseTensorAllocConverter
@@ -1564,7 +1576,7 @@ void mlir::populateSparseTensorCodegenPatterns(
15641576
SparseCastConverter, SparseExtractSliceConverter,
15651577
SparseTensorLoadConverter, SparseExpandConverter,
15661578
SparseCompressConverter, SparseInsertConverter,
1567-
SparseReorderCOOConverter,
1579+
SparseReorderCOOConverter, SparseReMapConverter,
15681580
SparseSliceGetterOpConverter<ToSliceOffsetOp,
15691581
StorageSpecifierKind::DimOffset>,
15701582
SparseSliceGetterOpConverter<ToSliceStrideOp,

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,18 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
336336
}
337337
};
338338

339+
class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
340+
public:
341+
using OpConversionPattern::OpConversionPattern;
342+
LogicalResult
343+
matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
344+
ConversionPatternRewriter &rewriter) const override {
345+
// Simply fold the operation.
346+
rewriter.replaceOp(op, adaptor.getSource());
347+
return success();
348+
}
349+
};
350+
339351
/// Sparse conversion rule for the new operator.
340352
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
341353
public:
@@ -770,7 +782,7 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
770782
RewritePatternSet &patterns) {
771783
patterns
772784
.add<SparseReturnConverter, SparseTensorLvlOpConverter,
773-
SparseCastConverter, SparseTensorNewConverter,
785+
SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
774786
SparseTensorAllocConverter, SparseTensorEmptyConverter,
775787
SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
776788
SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,

mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@
3434
)
3535
}>
3636

37+
#DSDD = #sparse_tensor.encoding<{
38+
map = (i, j, k, l) -> ( i : dense, j : compressed, k : dense, l : dense)
39+
}>
40+
41+
3742
!Filename = !llvm.ptr<i8>
3843

3944
//
@@ -77,6 +82,13 @@ module {
7782
%vecv = vector.transfer_read %val[%c0], %f0 : memref<?xf64>, vector<12xf64>
7883
vector.print %vecv : vector<12xf64>
7984

85+
// CHECK-NEXT: ( 1, 2, 0, 3, 4, 0, 0, 5, 6, 7, 8, 0 )
86+
%t1 = sparse_tensor.reinterpret_map %A : tensor<?x?xf64, #BSR>
87+
to tensor<?x?x2x2xf64, #DSDD>
88+
%vdsdd = sparse_tensor.values %t1 : tensor<?x?x2x2xf64, #DSDD> to memref<?xf64>
89+
%vecdsdd = vector.transfer_read %vdsdd[%c0], %f0 : memref<?xf64>, vector<12xf64>
90+
vector.print %vecdsdd : vector<12xf64>
91+
8092
// Release the resources.
8193
bufferization.dealloc_tensor %A: tensor<?x?xf64, #BSR>
8294

0 commit comments

Comments
 (0)