Skip to content

Commit c780352

Browse files
authored
[mlir][sparse] implement sparse_tensor.lvl operation. (#69993)
1 parent 260dbb4 commit c780352

15 files changed

+132
-106
lines changed

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -915,7 +915,7 @@ Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc,
915915
// properly handle non-permutations.
916916
Dimension mlir::sparse_tensor::toOrigDim(RankedTensorType type, Level l) {
917917
const auto enc = getSparseTensorEncoding(type);
918-
assert(l < enc.getLvlRank());
918+
assert(!enc || l < enc.getLvlRank());
919919
return toOrigDim(enc, l);
920920
}
921921

@@ -1208,6 +1208,12 @@ LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
12081208
return success();
12091209
}
12101210

1211+
void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1212+
int64_t index) {
1213+
Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
1214+
return build(builder, state, source, val);
1215+
}
1216+
12111217
LogicalResult LvlOp::verify() {
12121218
if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
12131219
auto stt = getSparseTensorType(getSource());

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -428,12 +428,13 @@ void LoopEmitter::initializeLoopEmit(
428428
const auto enc = getSparseTensorEncoding(rtp);
429429
const Level cooStart = enc ? getCOOStart(enc) : lvlRank;
430430

431-
SmallVector<Value> dimSz;
432-
for (Dimension d = 0; d < stt.getDimRank(); d++)
433-
dimSz.push_back(linalg::createOrFoldDimOp(builder, loc, tensor, d));
434-
435-
ValueRange lvlSzs =
436-
enc.translateCrds(builder, loc, dimSz, CrdTransDirectionKind::dim2lvl);
431+
SmallVector<Value> lvlSzs;
432+
for (Level l = 0; l < stt.getLvlRank(); l++) {
433+
if (stt.hasEncoding())
434+
lvlSzs.push_back(builder.create<LvlOp>(loc, tensor, l));
435+
else
436+
lvlSzs.push_back(builder.create<tensor::DimOp>(loc, tensor, l));
437+
}
437438

438439
// Scan all levels of current tensor.
439440
for (Level l = 0; l < lvlRank; l++) {
@@ -489,7 +490,8 @@ void LoopEmitter::initializeLoopEmit(
489490
valBuffer[t] = denseVal;
490491
} else {
491492
// Annotated sparse tensors.
492-
// We also need the value buffer for all-dense annotated "sparse" tensors.
493+
// We also need the value buffer for all-dense annotated "sparse"
494+
// tensors.
493495
valBuffer[t] = genToValues(builder, loc, tensor);
494496
}
495497
// NOTE: we can also prepare for 0 lvl here in advance, this will hoist

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 12 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -97,32 +97,6 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
9797
return forOp;
9898
}
9999

100-
/// Gets the dimension size for the given sparse tensor at the given
101-
/// original dimension 'dim'.
102-
static Value sizeFromTensorAtDim(OpBuilder &builder, Location loc,
103-
SparseTensorDescriptor desc, Dimension dim) {
104-
const SparseTensorType stt(desc.getRankedTensorType());
105-
// Access into static dimension can query original type directly.
106-
// Note that this is typically already done by DimOp's folding.
107-
if (auto sz = stt.getStaticDimSize(dim))
108-
return constantIndex(builder, loc, *sz);
109-
110-
// Any other query can consult the dimSizes array at field DimSizesIdx,
111-
// accounting for the reordering applied to the sparse storage.
112-
// FIXME: `toStoredDim` is deprecated.
113-
const Level lvl = toStoredDim(stt, dim);
114-
return desc.getLvlSize(builder, loc, lvl);
115-
}
116-
117-
// Gets the dimension size at the given stored level 'lvl', either as a
118-
// constant for a static size, or otherwise dynamically through memSizes.
119-
static Value sizeFromTensorAtLvl(OpBuilder &builder, Location loc,
120-
SparseTensorDescriptor desc, Level lvl) {
121-
// FIXME: `toOrigDim` is deprecated.
122-
return sizeFromTensorAtDim(builder, loc, desc,
123-
toOrigDim(desc.getRankedTensorType(), lvl));
124-
}
125-
126100
static void createPushback(OpBuilder &builder, Location loc,
127101
MutSparseTensorDescriptor desc,
128102
SparseTensorFieldKind kind, std::optional<Level> lvl,
@@ -164,7 +138,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
164138
// at this level. We will eventually reach a compressed level or
165139
// otherwise the values array for the from-here "all-dense" case.
166140
assert(isDenseDLT(dlt));
167-
Value size = sizeFromTensorAtLvl(builder, loc, desc, l);
141+
Value size = desc.getLvlSize(builder, loc, l);
168142
linear = builder.create<arith::MulIOp>(loc, linear, size);
169143
}
170144
// Reached values array so prepare for an insertion.
@@ -448,7 +422,7 @@ class SparseInsertGenerator
448422
// Construct the new position as:
449423
// positions[l] = size * positions[l-1] + coords[l]
450424
// <insert @ positions[l] at next level l + 1>
451-
Value size = sizeFromTensorAtLvl(builder, loc, desc, l);
425+
Value size = desc.getLvlSize(builder, loc, l);
452426
Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
453427
parentPos = builder.create<arith::AddIOp>(loc, mult, coords[l]);
454428
}
@@ -658,19 +632,19 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
658632
}
659633
};
660634

661-
/// Sparse codegen rule for dimension accesses.
662-
class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
635+
/// Sparse codegen rule for level accesses.
636+
class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
663637
public:
664638
using OpConversionPattern::OpConversionPattern;
665639
LogicalResult
666-
matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
640+
matchAndRewrite(LvlOp op, OpAdaptor adaptor,
667641
ConversionPatternRewriter &rewriter) const override {
668-
std::optional<int64_t> dim = op.getConstantIndex();
669-
if (!dim || !getSparseTensorEncoding(adaptor.getSource().getType()))
642+
std::optional<int64_t> lvl = op.getConstantLvlIndex();
643+
if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType()))
670644
return failure();
671645

672646
auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
673-
auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *dim);
647+
auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
674648

675649
rewriter.replaceOp(op, sz);
676650
return success();
@@ -922,12 +896,10 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
922896
Type idxType = rewriter.getIndexType();
923897
// All initialization should be done on entry of the loop nest.
924898
rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
899+
925900
// Determine the size for access expansion (always the innermost stored
926-
// level size, translated back to original dimension). Note that we
927-
// recursively rewrite the new DimOp on the **original** tensor.
928-
// FIXME: `toOrigDim` is deprecated.
929-
const Dimension innerDim = toOrigDim(srcType, srcType.getLvlRank() - 1);
930-
const auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim);
901+
// level size).
902+
const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
931903
// Generate a memref for `sz` elements of type `t`.
932904
const auto genAlloc = [&](Type t) {
933905
const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
@@ -1588,7 +1560,7 @@ void mlir::populateSparseTensorCodegenPatterns(
15881560
TypeConverter &typeConverter, RewritePatternSet &patterns,
15891561
bool createSparseDeallocs, bool enableBufferInitialization) {
15901562
patterns.add<SparseAssembleOpConverter, SparseDisassembleOpConverter,
1591-
SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
1563+
SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
15921564
SparseCastConverter, SparseExtractSliceConverter,
15931565
SparseTensorLoadConverter, SparseExpandConverter,
15941566
SparseCompressConverter, SparseInsertConverter,

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -293,26 +293,28 @@ class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
293293
}
294294
};
295295

296-
/// Sparse conversion rule for accessing dimension-sizes.
297-
class SparseTensorToDimSizeConverter
298-
: public OpConversionPattern<tensor::DimOp> {
296+
/// Sparse conversion rule for accessing level-sizes.
297+
class SparseTensorLvlOpConverter : public OpConversionPattern<LvlOp> {
299298
public:
300299
using OpConversionPattern::OpConversionPattern;
301300
LogicalResult
302-
matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
301+
matchAndRewrite(LvlOp op, OpAdaptor adaptor,
303302
ConversionPatternRewriter &rewriter) const override {
304303
const auto stt = getSparseTensorType(op.getSource());
305304
// Only rewrite sparse LvlOp.
306305
if (!stt.hasEncoding())
307306
return failure();
307+
308308
// Only rewrite LvlOp with constant index.
309-
std::optional<int64_t> dim = op.getConstantIndex();
310-
if (!dim)
309+
std::optional<int64_t> lvl = op.getConstantLvlIndex();
310+
311+
if (!lvl)
311312
return failure();
312-
// Generate the call.
313+
314+
// By now, if the level size is constant, the operation should have already
315+
// been folded by LvlOp's folder, so we generate the call unconditionally.
313316
Value src = adaptor.getOperands()[0];
314-
rewriter.replaceOp(
315-
op, createOrFoldDimCall(rewriter, op->getLoc(), stt, src, *dim));
317+
rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
316318
return success();
317319
}
318320
};
@@ -767,7 +769,7 @@ mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
767769
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
768770
RewritePatternSet &patterns) {
769771
patterns
770-
.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
772+
.add<SparseReturnConverter, SparseTensorLvlOpConverter,
771773
SparseCastConverter, SparseTensorNewConverter,
772774
SparseTensorAllocConverter, SparseTensorEmptyConverter,
773775
SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,48 @@ struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
888888
}
889889
};
890890

891+
struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
892+
using OpRewritePattern::OpRewritePattern;
893+
LogicalResult matchAndRewrite(tensor::DimOp op,
894+
PatternRewriter &rewriter) const override {
895+
std::optional<int64_t> dim = op.getConstantIndex();
896+
auto stt = getSparseTensorType(op.getSource());
897+
if (!dim || !stt.hasEncoding())
898+
return failure();
899+
900+
if (stt.isPermutation()) {
901+
rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
902+
toStoredDim(stt, *dim));
903+
return success();
904+
}
905+
906+
// Non-permutation dim2lvl/lvl2dim maps.
907+
// Compute as follows:
908+
// affine.apply #map (l0 - 1, l1 - 1, ...) + 1
909+
// Note that it is not the most efficient way (but a more general one) for
910+
// the lvl to dim translation, e.g., for BSR, the dimension size can be
911+
// computed simply by lvl_size * block_size.
912+
Location loc = op.getLoc();
913+
SmallVector<Value> maxLvlCrds;
914+
for (Level l = 0; l < stt.getLvlRank(); l++) {
915+
Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
916+
Value maxLvlCrd = rewriter.create<arith::SubIOp>(
917+
loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
918+
maxLvlCrds.push_back(maxLvlCrd);
919+
}
920+
921+
AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
922+
Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
923+
op.getLoc(), AffineMap::get(stt.getLvlRank(), 0, lvl2DimExp),
924+
maxLvlCrds);
925+
926+
Value dimSz = rewriter.create<arith::AddIOp>(
927+
loc, maxDimCrd, constantOne(rewriter, loc, rewriter.getIndexType()));
928+
rewriter.replaceOp(op, dimSz);
929+
return success();
930+
}
931+
};
932+
891933
struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
892934
using OpRewritePattern::OpRewritePattern;
893935
LogicalResult matchAndRewrite(ConcatenateOp op,
@@ -1270,7 +1312,8 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
12701312
ReshapeRewriter<tensor::CollapseShapeOp>,
12711313
Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
12721314
Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
1273-
TensorReshapeRewriter>(patterns.getContext());
1315+
SparseTensorDimOpRewriter, TensorReshapeRewriter>(
1316+
patterns.getContext());
12741317
if (enableForeach)
12751318
patterns.add<ForeachRewriter>(patterns.getContext());
12761319
if (enableConvert)

mlir/test/Dialect/SparseTensor/codegen.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize -cse | FileCheck %s
1+
// RUN: mlir-opt %s --post-sparsification-rewrite --sparse-tensor-codegen --canonicalize -cse | FileCheck %s
22

33
#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
44

mlir/test/Dialect/SparseTensor/conversion.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
1+
// RUN: mlir-opt %s --post-sparsification-rewrite --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
22

33
#SparseVector = #sparse_tensor.encoding<{
44
map = (d0) -> (d0 : compressed)
@@ -38,7 +38,7 @@ func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #Spa
3838
// CHECK-LABEL: func @sparse_dim1d(
3939
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
4040
// CHECK: %[[C:.*]] = arith.constant 0 : index
41-
// CHECK: %[[D:.*]] = call @sparseDimSize(%[[A]], %[[C]])
41+
// CHECK: %[[D:.*]] = call @sparseLvlSize(%[[A]], %[[C]])
4242
// CHECK: return %[[D]] : index
4343
func.func @sparse_dim1d(%arg0: tensor<?xf64, #SparseVector>) -> index {
4444
%c = arith.constant 0 : index
@@ -51,8 +51,8 @@ func.func @sparse_dim1d(%arg0: tensor<?xf64, #SparseVector>) -> index {
5151
// dimension 1 is stored as level 2).
5252
// CHECK-LABEL: func @sparse_dim3d(
5353
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
54-
// CHECK: %[[C:.*]] = arith.constant 1 : index
55-
// CHECK: %[[D:.*]] = call @sparseDimSize(%[[A]], %[[C]])
54+
// CHECK: %[[C:.*]] = arith.constant 2 : index
55+
// CHECK: %[[D:.*]] = call @sparseLvlSize(%[[A]], %[[C]])
5656
// CHECK: return %[[D]] : index
5757
func.func @sparse_dim3d(%arg0: tensor<?x?x?xf64, #SparseTensor>) -> index {
5858
%c = arith.constant 1 : index

mlir/test/Dialect/SparseTensor/sparse_2d.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1511,7 +1511,7 @@ func.func @sum_reduction(%arga: tensor<10x20xf32, #Tds>, %argx: tensor<f32>) ->
15111511
// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
15121512
// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
15131513
// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
1514-
// CHECK-DAG: %[[VAL_8:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
1514+
// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
15151515
// CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?xf64>
15161516
// CHECK: linalg.fill ins(%{{.*}} : f64) outs(%[[VAL_11]] : memref<?x?xf64>)
15171517
// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] {
@@ -1641,7 +1641,7 @@ func.func @sampled_dense_dense(%args: tensor<?x?xf32, #Tss>,
16411641
// CHECK-DAG: %[[VAL_19:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
16421642
// CHECK-DAG: %[[VAL_20:.*]] = bufferization.to_memref %[[VAL_3]] : memref<?xf32>
16431643
// CHECK-DAG: %[[VAL_21:.*]] = bufferization.to_memref %[[VAL_4]] : memref<f32>
1644-
// CHECK-DAG: %[[VAL_22:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32,
1644+
// CHECK-DAG: %[[VAL_22:.*]] = sparse_tensor.lvl %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32,
16451645
// CHECK-DAG: %[[VAL_24:.*]] = bufferization.to_memref %[[VAL_5]] : memref<?xf32>
16461646
// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_21]][] : memref<f32>
16471647
// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_6]]] : memref<?xindex>

mlir/test/Dialect/SparseTensor/sparse_3d.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,10 +1126,10 @@ func.func @mul_sss(%arga: tensor<32x16x8xf32, #Tsss>, %argb: tensor<32x16x8xf32>
11261126
// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 2 : index} : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
11271127
// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 2 : index} : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
11281128
// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
1129-
// CHECK-DAG: %[[VAL_10:.*]] = tensor.dim %[[VAL_1]], %[[VAL_6]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
1129+
// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.lvl %[[VAL_1]], %[[VAL_6]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
11301130
// CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<?x?xf32>
11311131
// CHECK-DAG: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_3]] : memref<?x?xf32>
1132-
// CHECK-DAG: %[[VAL_13:.*]] = tensor.dim %[[VAL_1]], %[[VAL_5]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
1132+
// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.lvl %[[VAL_1]], %[[VAL_5]] : tensor<?x?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
11331133
// CHECK-DAG: %[[VAL_14:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32>
11341134
// CHECK-DAG: %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_0]] : memref<?x?xf32>
11351135
// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_5]] to %[[VAL_13]] step %[[VAL_6]] {

mlir/test/Dialect/SparseTensor/sparse_expand.mlir

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
// RUN: FileCheck %s --check-prefix=CHECK-SPARSE
55
// RUN: mlir-opt %s --linalg-generalize-named-ops \
66
// RUN: --linalg-fuse-elementwise-ops \
7-
// RUN: --sparsification --sparse-tensor-conversion --cse | \
7+
// RUN: --sparsification --post-sparsification-rewrite \
8+
// RUN: --sparse-tensor-conversion --cse | \
89
// RUN: FileCheck %s --check-prefix=CHECK-CONVERT
910

1011
#CSR = #sparse_tensor.encoding<{
@@ -45,8 +46,9 @@
4546
//
4647
// CHECK-CONVERT-LABEL: func @kernel(
4748
// CHECK-CONVERT-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
48-
// CHECK-CONVERT: %[[C0:.*]] = arith.constant 0 : index
49-
// CHECK-CONVERT: %[[N:.*]] = call @sparseDimSize(%[[A]], %[[C0]])
49+
// CHECK-CONVERT-DAG: %[[C1:.*]] = arith.constant 1 : index
50+
// CHECK-CONVERT-DAG: %[[C0:.*]] = arith.constant 0 : index
51+
// CHECK-CONVERT: %[[N:.*]] = call @sparseLvlSize(%[[A]], %[[C1]])
5052
// CHECK-CONVERT: %[[V:.*]] = call @newSparseTensor
5153
// CHECK-CONVERT: %[[S:.*]] = call @sparseLvlSize(%[[V]], %[[C0]])
5254
// CHECK-CONVERT: %[[A:.*]] = memref.alloc(%[[S]]) : memref<?xf64>

0 commit comments

Comments
 (0)