@@ -881,25 +881,27 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
881881 PatternRewriter &rewriter) const override {
882882 Location loc = op.getLoc ();
883883 Value srcTensor = op.getSource ();
884- const auto srcTp = getSparseTensorType (srcTensor);
885- const auto dstTp = getSparseTensorType (op.getResult ());
884+ const auto srcTp = tryGetSparseTensorType (srcTensor);
885+ const auto dstTp = tryGetSparseTensorType (op.getResult ());
886+ if (!srcTp || !dstTp)
887+ return failure ();
886888
887- if (!srcTp. hasEncoding () || !dstTp. hasEncoding () ||
888- !dstTp. hasStaticDimShape ())
889+ if (!srcTp-> hasEncoding () || !dstTp-> hasEncoding () ||
890+ !dstTp-> hasStaticDimShape ())
889891 return failure ();
890892
891893 SmallVector<Value> srcSizes;
892- sizesForTensor (rewriter, srcSizes, loc, srcTp, srcTensor);
894+ sizesForTensor (rewriter, srcSizes, loc, * srcTp, srcTensor);
893895 SmallVector<Value> dstSizes;
894- for (Dimension d : dstTp. getDimShape ())
896+ for (Dimension d : dstTp-> getDimShape ())
895897 dstSizes.push_back (constantIndex (rewriter, loc, d));
896898
897899 Value nnz = rewriter.create <NumberOfEntriesOp>(loc, srcTensor);
898900 // Only need an unordered COO buffer if input and output are not sorted
899901 // in the same way.
900902 Type bufferTp = getBufferType (
901- dstTp. withoutDimToLvl (),
902- !srcTp. isAllOrdered () || !srcTp. isIdentity () || !dstTp. isIdentity ());
903+ dstTp-> withoutDimToLvl (),
904+ !srcTp-> isAllOrdered () || !srcTp-> isIdentity () || !dstTp-> isIdentity ());
903905 SmallVector<Value> dynSizes;
904906 Value buffer = rewriter
905907 .create <AllocTensorOp>(loc, bufferTp, dynSizes, Value (),
@@ -917,12 +919,12 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
917919 // followed by an optional
918920 // %t = sparse_tensor.cast %tmp
919921 // depending on whether the input/output are sorted in the same way.
920- const auto encSrc = srcTp. getEncoding ();
922+ const auto encSrc = srcTp-> getEncoding ();
921923 ForeachOp foreachOp = rewriter.create <ForeachOp>(
922924 loc, srcTensor, buffer,
923925 [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
924926 ValueRange reduc) {
925- const Dimension srcRank = srcTp. getDimRank ();
927+ const Dimension srcRank = srcTp-> getDimRank ();
926928 SmallVector<Value> srcDcvs;
927929 srcDcvs.reserve (srcRank);
928930 for (Dimension d = 0 ; d < srcRank; d++) {
@@ -945,7 +947,7 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
945947 collapsedSizes, collapsedDcvs);
946948
947949 ReassociationIndices expandIdx;
948- for (Dimension i = 0 ; i < dstTp. getDimRank (); i++)
950+ for (Dimension i = 0 ; i < dstTp-> getDimRank (); i++)
949951 expandIdx.push_back (i);
950952 SmallVector<ReassociationIndices, 1 > expandReass = {expandIdx};
951953 SmallVector<Value> dstDcvs;
@@ -958,8 +960,8 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
958960 });
959961
960962 Value t = rewriter.create <LoadOp>(loc, foreachOp.getResult (0 ), true );
961- if (bufferTp != dstTp) {
962- auto dstRTT = dstTp. getRankedTensorType ();
963+ if (bufferTp != * dstTp) {
964+ auto dstRTT = dstTp-> getRankedTensorType ();
963965 Value converted = rewriter.create <ConvertOp>(loc, dstRTT, t).getResult ();
964966 rewriter.create <DeallocTensorOp>(loc, t);
965967 t = converted;
@@ -1139,13 +1141,13 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
11391141 LogicalResult matchAndRewrite (tensor::DimOp op,
11401142 PatternRewriter &rewriter) const override {
11411143 std::optional<int64_t > dim = op.getConstantIndex ();
1142- auto stt = getSparseTensorType (op.getSource ());
1143- if (!dim || !stt. hasEncoding ())
1144+ auto stt = tryGetSparseTensorType (op.getSource ());
1145+ if (!dim || !stt || !stt-> hasEncoding ())
11441146 return failure ();
11451147
1146- if (stt. isPermutation ()) {
1148+ if (stt-> isPermutation ()) {
11471149 rewriter.replaceOpWithNewOp <LvlOp>(op, op.getSource (),
1148- toLvl (stt. getEncoding (), *dim));
1150+ toLvl (stt-> getEncoding (), *dim));
11491151 return success ();
11501152 }
11511153
@@ -1157,16 +1159,16 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
11571159 // computed simply by lvl_size * block_size.
11581160 Location loc = op.getLoc ();
11591161 SmallVector<Value> maxLvlCrds;
1160- for (Level l = 0 ; l < stt. getLvlRank (); l++) {
1162+ for (Level l = 0 ; l < stt-> getLvlRank (); l++) {
11611163 Value lvlSz = rewriter.create <LvlOp>(loc, op.getSource (), l);
11621164 Value maxLvlCrd = rewriter.create <arith::SubIOp>(
11631165 loc, lvlSz, constantOne (rewriter, loc, rewriter.getIndexType ()));
11641166 maxLvlCrds.push_back (maxLvlCrd);
11651167 }
11661168
1167- AffineExpr lvl2DimExp = stt. getLvlToDim ().getResult (*dim);
1169+ AffineExpr lvl2DimExp = stt-> getLvlToDim ().getResult (*dim);
11681170 Value maxDimCrd = rewriter.create <affine::AffineApplyOp>(
1169- op.getLoc (), AffineMap::get (stt. getLvlRank (), 0 , lvl2DimExp),
1171+ op.getLoc (), AffineMap::get (stt-> getLvlRank (), 0 , lvl2DimExp),
11701172 maxLvlCrds);
11711173
11721174 Value dimSz = rewriter.create <arith::AddIOp>(
0 commit comments