@@ -646,10 +646,11 @@ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
646646 matchAndRewrite (LvlOp op, OpAdaptor adaptor,
647647 ConversionPatternRewriter &rewriter) const override {
648648 std::optional<int64_t > lvl = op.getConstantLvlIndex ();
649- if (!lvl || !getSparseTensorEncoding (adaptor.getSource ().getType ()))
649+ RankedTensorType srcType = op.getSource ().getType ();
650+ if (!lvl || !getSparseTensorEncoding (srcType))
650651 return failure ();
651652
652- auto desc = getDescriptorFromTensorTuple (adaptor.getSource ());
653+ auto desc = getDescriptorFromTensorTuple (adaptor.getSource (), srcType );
653654 auto sz = desc.getLvlSize (rewriter, op.getLoc (), *lvl);
654655
655656 rewriter.replaceOp (op, sz);
@@ -675,8 +676,9 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
675676 assert (dstStt.hasSameDimToLvl (srcStt));
676677
677678 // We don't need a mutable descriptor here as we perform sorting in-place.
678- auto nnz = genValMemSize (rewriter, op.getLoc (), adaptor.getInputCoo ());
679- auto desc = getDescriptorFromTensorTuple (adaptor.getInputCoo ());
679+ auto desc = getDescriptorFromTensorTuple (adaptor.getInputCoo (),
680+ op.getInputCoo ().getType ());
681+ auto nnz = desc.getValMemSize (rewriter, op.getLoc ());
680682 auto crd = desc.getAOSMemRef ();
681683 auto val = desc.getValMemRef ();
682684
@@ -704,7 +706,8 @@ class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
704706 matchAndRewrite (Op op, typename Op::Adaptor adaptor,
705707 ConversionPatternRewriter &rewriter) const override {
706708 // Simply lowers to specifer.get <field> operation.
707- auto desc = getDescriptorFromTensorTuple (adaptor.getSlice ());
709+ auto desc = getDescriptorFromTensorTuple (adaptor.getSlice (),
710+ op.getSlice ().getType ());
708711 auto v = desc.getSpecifierField (rewriter, op.getLoc (), kind,
709712 op.getDim ().getZExtValue ());
710713
@@ -762,7 +765,8 @@ class SparseTensorAllocConverter
762765 Location loc = op.getLoc ();
763766 // Deal with copy.
764767 if (op.getCopy ()) {
765- auto desc = getDescriptorFromTensorTuple (adaptor.getCopy ());
768+ auto desc = getDescriptorFromTensorTuple (
769+ adaptor.getCopy (), cast<RankedTensorType>(op.getCopy ().getType ()));
766770 SmallVector<Value> fields;
767771 fields.reserve (desc.getNumFields ());
768772 // Memcpy on memref fields.
@@ -868,7 +872,9 @@ class SparseTensorDeallocConverter
868872 if (createDeallocs) {
869873 // Replace the sparse tensor deallocation with field deallocations.
870874 Location loc = op.getLoc ();
871- auto desc = getDescriptorFromTensorTuple (adaptor.getTensor ());
875+ auto desc = getDescriptorFromTensorTuple (
876+ adaptor.getTensor (),
877+ cast<RankedTensorType>(op.getTensor ().getType ()));
872878 for (auto input : desc.getMemRefFields ())
873879 // Deallocate every buffer used to store the sparse tensor handler.
874880 rewriter.create <memref::DeallocOp>(loc, input);
@@ -889,7 +895,8 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
889895 matchAndRewrite (LoadOp op, OpAdaptor adaptor,
890896 ConversionPatternRewriter &rewriter) const override {
891897 // Prepare descriptor.
892- auto desc = getDescriptorFromTensorTuple (adaptor.getTensor ());
898+ auto desc = getDescriptorFromTensorTuple (adaptor.getTensor (),
899+ op.getTensor ().getType ());
893900 // Generate optional insertion finalization code.
894901 if (op.getHasInserts ())
895902 genEndInsert (rewriter, op.getLoc (), desc);
@@ -909,7 +916,8 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
909916 if (!getSparseTensorEncoding (op.getTensor ().getType ()))
910917 return failure ();
911918 Location loc = op->getLoc ();
912- auto desc = getDescriptorFromTensorTuple (adaptor.getTensor ());
919+ auto desc = getDescriptorFromTensorTuple (adaptor.getTensor (),
920+ op.getTensor ().getType ());
913921 const auto srcType = getSparseTensorType (op.getTensor ());
914922 Type eltType = srcType.getElementType ();
915923 Type boolType = rewriter.getIntegerType (1 );
@@ -959,7 +967,8 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
959967 ConversionPatternRewriter &rewriter) const override {
960968 Location loc = op->getLoc ();
961969 SmallVector<Value> fields;
962- auto desc = getMutDescriptorFromTensorTuple (adaptor.getTensor (), fields);
970+ auto desc = getMutDescriptorFromTensorTuple (adaptor.getTensor (), fields,
971+ op.getTensor ().getType ());
963972 Value values = adaptor.getValues ();
964973 Value filled = adaptor.getFilled ();
965974 Value added = adaptor.getAdded ();
@@ -1032,7 +1041,8 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
10321041 assert (stt.isIdentity () && " Run reinterpret-map before conversion." );
10331042
10341043 Location loc = op.getLoc ();
1035- auto desc = getDescriptorFromTensorTuple (adaptor.getDest ());
1044+ auto desc =
1045+ getDescriptorFromTensorTuple (adaptor.getDest (), op.getDest ().getType ());
10361046 TypeRange flatSpTensorTps = desc.getFields ().getTypes ();
10371047 SmallVector<Value> params = llvm::to_vector (desc.getFields ());
10381048 params.append (adaptor.getIndices ().begin (), adaptor.getIndices ().end ());
@@ -1059,7 +1069,8 @@ class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
10591069 // of this operation truly observe size, not capacity!
10601070 Location loc = op.getLoc ();
10611071 Level lvl = op.getLevel ();
1062- auto desc = getDescriptorFromTensorTuple (adaptor.getTensor ());
1072+ auto desc = getDescriptorFromTensorTuple (adaptor.getTensor (),
1073+ op.getTensor ().getType ());
10631074 auto mem = desc.getPosMemRef (lvl);
10641075 auto size = desc.getPosMemSize (rewriter, loc, lvl);
10651076 rewriter.replaceOp (op, genSliceToSize (rewriter, loc, mem, size));
@@ -1081,7 +1092,8 @@ class SparseToCoordinatesConverter
10811092 // of this operation truly observe size, not capacity!
10821093 Location loc = op.getLoc ();
10831094 Level lvl = op.getLevel ();
1084- auto desc = getDescriptorFromTensorTuple (adaptor.getTensor ());
1095+ auto desc = getDescriptorFromTensorTuple (adaptor.getTensor (),
1096+ op.getTensor ().getType ());
10851097 auto mem = desc.getCrdMemRefOrView (rewriter, loc, lvl);
10861098 if (lvl < getSparseTensorType (op.getTensor ()).getAoSCOOStart ()) {
10871099 auto size = desc.getCrdMemSize (rewriter, loc, lvl);
@@ -1106,7 +1118,8 @@ class SparseToCoordinatesBufferConverter
11061118 // of this operation truly observe size, not capacity!
11071119 Location loc = op.getLoc ();
11081120 Level lvl = getSparseTensorType (op.getTensor ()).getAoSCOOStart ();
1109- auto desc = getDescriptorFromTensorTuple (adaptor.getTensor ());
1121+ auto desc = getDescriptorFromTensorTuple (adaptor.getTensor (),
1122+ op.getTensor ().getType ());
11101123 auto mem = desc.getAOSMemRef ();
11111124 auto size = desc.getCrdMemSize (rewriter, loc, lvl);
11121125 rewriter.replaceOp (op, genSliceToSize (rewriter, loc, mem, size));
@@ -1126,7 +1139,8 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
11261139 // The view is restricted to the actual size to ensure clients
11271140 // of this operation truly observe size, not capacity!
11281141 Location loc = op.getLoc ();
1129- auto desc = getDescriptorFromTensorTuple (adaptor.getTensor ());
1142+ auto desc = getDescriptorFromTensorTuple (adaptor.getTensor (),
1143+ op.getTensor ().getType ());
11301144 auto mem = desc.getValMemRef ();
11311145 auto size = desc.getValMemSize (rewriter, loc);
11321146 rewriter.replaceOp (op, genSliceToSize (rewriter, loc, mem, size));
@@ -1172,7 +1186,8 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
11721186 // else:
11731187 // dst = memref.copy(src)
11741188 Location loc = op.getLoc ();
1175- auto srcDesc = getDescriptorFromTensorTuple (adaptor.getSource ());
1189+ auto srcDesc = getDescriptorFromTensorTuple (adaptor.getSource (),
1190+ op.getSource ().getType ());
11761191 SmallVector<Value> fields;
11771192 foreachFieldAndTypeInSparseTensor (
11781193 SparseTensorType (cast<RankedTensorType>(op.getResult ().getType ())),
@@ -1236,7 +1251,8 @@ class SparseExtractSliceConverter
12361251 assert (srcEnc.withoutDimSlices () == dstEnc.withoutDimSlices ());
12371252
12381253 SmallVector<Value> fields;
1239- auto desc = getMutDescriptorFromTensorTuple (adaptor.getSource (), fields);
1254+ auto desc = getMutDescriptorFromTensorTuple (adaptor.getSource (), fields,
1255+ op.getSource ().getType ());
12401256
12411257 auto newSpec = rewriter.create <StorageSpecifierInitOp>(
12421258 loc, StorageSpecifierType::get (ctx, dstEnc), desc.getSpecifier ());
@@ -1285,8 +1301,9 @@ class SparseNumberOfEntriesConverter
12851301 // Query memSizes for the actually stored values.
12861302 // FIXME: the nse value computed in this way might be wrong when there is
12871303 // any "loose_compressed" level.
1288- rewriter.replaceOp (
1289- op, genValMemSize (rewriter, op.getLoc (), adaptor.getTensor ()));
1304+ auto desc = getDescriptorFromTensorTuple (adaptor.getTensor (),
1305+ op.getTensor ().getType ());
1306+ rewriter.replaceOp (op, desc.getValMemSize (rewriter, op.getLoc ()));
12901307 return success ();
12911308 }
12921309};
@@ -1415,7 +1432,8 @@ struct SparseDisassembleOpConverter
14151432 LogicalResult
14161433 matchAndRewrite (DisassembleOp op, OpAdaptor adaptor,
14171434 ConversionPatternRewriter &rewriter) const override {
1418- auto desc = getDescriptorFromTensorTuple (adaptor.getTensor ());
1435+ auto desc = getDescriptorFromTensorTuple (adaptor.getTensor (),
1436+ op.getTensor ().getType ());
14191437 Location loc = op.getLoc ();
14201438 SmallVector<Value> retMem;
14211439 SmallVector<Value> retLen;
0 commit comments