@@ -600,8 +600,8 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
600600 flattenOperands (adaptor.getOperands (), flattened);
601601 auto newCall = rewriter.create <func::CallOp>(loc, op.getCallee (),
602602 finalRetTy, flattened);
603- // (2) Create cast operation for sparse tensor returns.
604- SmallVector<Value> castedRet ;
603+ // (2) Gather sparse tensor returns.
604+ SmallVector<SmallVector< Value>> packedResultVals ;
605605 // Tracks the offset of current return value (of the original call)
606606 // relative to the new call (after sparse tensor flattening);
607607 unsigned retOffset = 0 ;
@@ -618,21 +618,22 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
618618 assert (!sparseFlat.empty ());
619619 if (sparseFlat.size () > 1 ) {
620620 auto flatSize = sparseFlat.size ();
621- ValueRange fields (iterator_range<ResultRange::iterator>(
622- newCall.result_begin () + retOffset,
623- newCall.result_begin () + retOffset + flatSize));
624- castedRet.push_back (genTuple (rewriter, loc, retType, fields));
621+ packedResultVals.emplace_back ();
622+ llvm::append_range (packedResultVals.back (),
623+ newCall.getResults ().slice (retOffset, flatSize));
625624 retOffset += flatSize;
626625 } else {
627626 // If this is an 1:1 conversion, no need for casting.
628- castedRet.push_back (newCall.getResult (retOffset));
627+ packedResultVals.emplace_back ();
628+ packedResultVals.back ().push_back (newCall.getResult (retOffset));
629629 retOffset++;
630630 }
631631 sparseFlat.clear ();
632632 }
633633
634- assert (castedRet.size () == op.getNumResults ());
635- rewriter.replaceOp (op, castedRet);
634+ assert (packedResultVals.size () == op.getNumResults ());
635+ rewriter.replaceOpWithMultiple (
636+ op, llvm::to_vector_of<ValueRange>(packedResultVals));
636637 return success ();
637638 }
638639};
@@ -776,7 +777,7 @@ class SparseTensorAllocConverter
776777 // Reuses specifier.
777778 fields.push_back (desc.getSpecifier ());
778779 assert (fields.size () == desc.getNumFields ());
779- rewriter.replaceOp (op, genTuple (rewriter, loc, resType, fields) );
780+ rewriter.replaceOpWithMultiple (op, { fields} );
780781 return success ();
781782 }
782783
@@ -796,7 +797,7 @@ class SparseTensorAllocConverter
796797 sizeHint, lvlSizesValues, fields);
797798
798799 // Replace operation with resulting memrefs.
799- rewriter.replaceOp (op, genTuple (rewriter, loc, resType, fields) );
800+ rewriter.replaceOpWithMultiple (op, { fields} );
800801 return success ();
801802 }
802803
@@ -837,7 +838,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
837838 sizeHint, lvlSizesValues, fields);
838839
839840 // Replace operation with resulting memrefs.
840- rewriter.replaceOp (op, genTuple (rewriter, loc, resType, fields) );
841+ rewriter.replaceOpWithMultiple (op, { fields} );
841842 return success ();
842843 }
843844
@@ -893,7 +894,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
893894 if (op.getHasInserts ())
894895 genEndInsert (rewriter, op.getLoc (), desc);
895896 // Replace operation with resulting memrefs.
896- rewriter.replaceOp (op, genTuple (rewriter, op. getLoc (), desc) );
897+ rewriter.replaceOpWithMultiple (op, {desc. getFields ()} );
897898 return success ();
898899 }
899900};
@@ -1006,15 +1007,14 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
10061007 rewriter.create <scf::YieldOp>(loc, insertRet);
10071008
10081009 rewriter.setInsertionPointAfter (loop);
1009- Value result = genTuple (rewriter, loc, dstType, loop->getResults ());
10101010 // Deallocate the buffers on exit of the full loop nest.
10111011 Operation *parent = getTop (op);
10121012 rewriter.setInsertionPointAfter (parent);
10131013 rewriter.create <memref::DeallocOp>(loc, values);
10141014 rewriter.create <memref::DeallocOp>(loc, filled);
10151015 rewriter.create <memref::DeallocOp>(loc, added);
10161016 // Replace operation with resulting memrefs.
1017- rewriter.replaceOp (op, result );
1017+ rewriter.replaceOpWithMultiple (op, {loop-> getResults ()} );
10181018 return success ();
10191019 }
10201020};
@@ -1041,8 +1041,7 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
10411041 params, /* genCall=*/ true );
10421042 SmallVector<Value> ret = insertGen.genCallOrInline (rewriter, loc);
10431043 // Replace operation with resulting memrefs.
1044- rewriter.replaceOp (op,
1045- genTuple (rewriter, loc, op.getDest ().getType (), ret));
1044+ rewriter.replaceOpWithMultiple (op, {ret});
10461045 return success ();
10471046 }
10481047};
@@ -1215,8 +1214,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
12151214 return true ;
12161215 });
12171216
1218- rewriter.replaceOp (
1219- op, genTuple (rewriter, loc, op.getResult ().getType (), fields));
1217+ rewriter.replaceOpWithMultiple (op, {fields});
12201218 return success ();
12211219 }
12221220};
@@ -1271,8 +1269,7 @@ class SparseExtractSliceConverter
12711269 // NOTE: we can not generate tuples directly from descriptor here, as the
12721270 // descriptor is holding the original type, yet we want the slice type
12731271 // here (they shared every memref but with an updated specifier).
1274- rewriter.replaceOp (op, genTuple (rewriter, loc, op.getResult ().getType (),
1275- desc.getFields ()));
1272+ rewriter.replaceOpWithMultiple (op, {desc.getFields ()});
12761273 return success ();
12771274 }
12781275};
@@ -1403,7 +1400,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
14031400 }
14041401 desc.setValMemSize (rewriter, loc, memSize);
14051402
1406- rewriter.replaceOp (op, genTuple (rewriter, loc, desc) );
1403+ rewriter.replaceOpWithMultiple (op, { desc. getFields ()} );
14071404 return success ();
14081405 }
14091406};
@@ -1577,7 +1574,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
15771574 EmitCInterface::Off);
15781575
15791576 // Replace operation with resulting memrefs.
1580- rewriter.replaceOp (op, genTuple (rewriter, loc, dstTp, fields) );
1577+ rewriter.replaceOpWithMultiple (op, { fields} );
15811578 return success ();
15821579 }
15831580};
0 commit comments