@@ -97,32 +97,6 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
9797 return forOp;
9898}
9999
100- /// Gets the dimension size for the given sparse tensor at the given
101- /// original dimension 'dim'.
102- static Value sizeFromTensorAtDim(OpBuilder &builder, Location loc,
103-                                  SparseTensorDescriptor desc, Dimension dim) {
104-   const SparseTensorType stt(desc.getRankedTensorType());
105-   // Access into static dimension can query original type directly.
106-   // Note that this is typically already done by DimOp's folding.
107-   if (auto sz = stt.getStaticDimSize(dim))
108-     return constantIndex(builder, loc, *sz);
109-
110-   // Any other query can consult the dimSizes array at field DimSizesIdx,
111-   // accounting for the reordering applied to the sparse storage.
112-   // FIXME: `toStoredDim` is deprecated.
113-   const Level lvl = toStoredDim(stt, dim);
114-   return desc.getLvlSize(builder, loc, lvl);
115- }
116-
117- // Gets the dimension size at the given stored level 'lvl', either as a
118- // constant for a static size, or otherwise dynamically through memSizes.
119- static Value sizeFromTensorAtLvl(OpBuilder &builder, Location loc,
120-                                  SparseTensorDescriptor desc, Level lvl) {
121-   // FIXME: `toOrigDim` is deprecated.
122-   return sizeFromTensorAtDim(builder, loc, desc,
123-                              toOrigDim(desc.getRankedTensorType(), lvl));
124- }
125-
126100static void createPushback(OpBuilder &builder, Location loc,
127101                           MutSparseTensorDescriptor desc,
128102                           SparseTensorFieldKind kind, std::optional<Level> lvl,
@@ -164,7 +138,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
164138 // at this level. We will eventually reach a compressed level or
165139 // otherwise the values array for the from-here "all-dense" case.
166140     assert(isDenseDLT(dlt));
167-       Value size = sizeFromTensorAtLvl(builder, loc, desc, l);
141+       Value size = desc.getLvlSize(builder, loc, l);
168142     linear = builder.create<arith::MulIOp>(loc, linear, size);
169143 }
170144 // Reached values array so prepare for an insertion.
@@ -448,7 +422,7 @@ class SparseInsertGenerator
448422 // Construct the new position as:
449423 // positions[l] = size * positions[l-1] + coords[l]
450424 // <insert @ positions[l] at next level l + 1>
451-       Value size = sizeFromTensorAtLvl(builder, loc, desc, l);
425+       Value size = desc.getLvlSize(builder, loc, l);
452426     Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
453427     parentPos = builder.create<arith::AddIOp>(loc, mult, coords[l]);
454428 }
@@ -658,19 +632,19 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
658632 }
659633};
660634
661- /// Sparse codegen rule for dimension accesses.
662- class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
635+ /// Sparse codegen rule for level accesses.
636+ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
663637public:
664638 using OpConversionPattern::OpConversionPattern;
665639 LogicalResult
666-   matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
640+   matchAndRewrite(LvlOp op, OpAdaptor adaptor,
667641                   ConversionPatternRewriter &rewriter) const override {
668-     std::optional<int64_t> dim = op.getConstantIndex();
669-     if (!dim || !getSparseTensorEncoding(adaptor.getSource().getType()))
642+     std::optional<int64_t> lvl = op.getConstantLvlIndex();
643+     if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType()))
670644       return failure();
671645
672646     auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
673-     auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *dim);
647+     auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
674648
675649     rewriter.replaceOp(op, sz);
676650     return success();
@@ -922,12 +896,10 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
922896     Type idxType = rewriter.getIndexType();
923897     // All initialization should be done on entry of the loop nest.
924898     rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
 899+
925900     // Determine the size for access expansion (always the innermost stored
926-       // level size, translated back to original dimension). Note that we
927-       // recursively rewrite the new DimOp on the **original** tensor.
928-       // FIXME: `toOrigDim` is deprecated.
929-       const Dimension innerDim = toOrigDim(srcType, srcType.getLvlRank() - 1);
930-       const auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim);
901+     // level size).
902+     const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
931903     // Generate a memref for `sz` elements of type `t`.
932904     const auto genAlloc = [&](Type t) {
933905       const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
@@ -1588,7 +1560,7 @@ void mlir::populateSparseTensorCodegenPatterns(
15881560 TypeConverter &typeConverter, RewritePatternSet &patterns,
15891561 bool createSparseDeallocs, bool enableBufferInitialization) {
15901562  patterns.add<SparseAssembleOpConverter, SparseDisassembleOpConverter,
1591-               SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
1563+               SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
15921564 SparseCastConverter, SparseExtractSliceConverter,
15931565 SparseTensorLoadConverter, SparseExpandConverter,
15941566 SparseCompressConverter, SparseInsertConverter,
0 commit comments