@@ -679,6 +679,50 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
679679 }
680680};
681681
682+ // TODO: use a new SortCOO operation here instead of reusing convert op.
683+ struct SparseSortCOOConverter : public OpConversionPattern <ConvertOp> {
684+ using OpConversionPattern::OpConversionPattern;
685+ LogicalResult
686+ matchAndRewrite (ConvertOp op, ConvertOpAdaptor adaptor,
687+ ConversionPatternRewriter &rewriter) const override {
688+ // Direct conversion should have already been lowered.
689+ if (!op.isSortCOOConvert ())
690+ return failure ();
691+
692+ Location loc = op.getLoc ();
693+ MLIRContext *ctx = op.getContext ();
694+
695+ SparseTensorType srcStt = getSparseTensorType (op.getSource ());
696+ SparseTensorType dstStt = getSparseTensorType (op.getDest ());
697+
698+ // TODO: This should be verification rules for sort_coo operation.
699+ assert (dstStt.isAllOrdered () && !srcStt.isAllOrdered () &&
700+ isUniqueCOOType (srcStt.getRankedTensorType ()) &&
701+ isUniqueCOOType (dstStt.getRankedTensorType ()));
702+
703+ assert (dstStt.hasSameDimToLvl (srcStt));
704+
705+ // We don't need a mutable descriptor here as we perform sorting in-place.
706+ auto nnz = genValMemSize (rewriter, op.getLoc (), adaptor.getSource ());
707+ auto desc = getDescriptorFromTensorTuple (adaptor.getSource ());
708+ auto crd = desc.getAOSMemRef ();
709+ auto val = desc.getValMemRef ();
710+
711+ // Otherwise we need another data shuffle and a non-identity map.
712+ assert (dstStt.hasSameDimToLvl (srcStt));
713+ auto id = AffineMap::getMultiDimIdentityMap (srcStt.getLvlRank (), ctx);
714+
715+ rewriter.create <SortOp>(loc, nnz, crd, ValueRange{val}, id,
716+ rewriter.getIndexAttr (0 ),
717+ SparseTensorSortKind::HybridQuickSort);
718+
719+ // Since we do in-place sorting, the destinate tensor will have the same set
720+ // of memrefs as the source tensor.
721+ rewriter.replaceOp (op, adaptor.getSource ());
722+ return success ();
723+ }
724+ };
725+
682726template <typename Op, StorageSpecifierKind kind>
683727class SparseSliceGetterOpConverter : public OpConversionPattern <Op> {
684728public:
@@ -1101,6 +1145,9 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
11011145 LogicalResult
11021146 matchAndRewrite (ConvertOp op, OpAdaptor adaptor,
11031147 ConversionPatternRewriter &rewriter) const override {
1148+ if (op.isSortCOOConvert ())
1149+ return failure ();
1150+
11041151 SparseTensorEncodingAttr encDst = getSparseTensorEncoding (op.getType ());
11051152 SparseTensorEncodingAttr encSrc =
11061153 getSparseTensorEncoding (op.getSource ().getType ());
@@ -1554,6 +1601,7 @@ void mlir::populateSparseTensorCodegenPatterns(
15541601 SparseCastConverter, SparseExtractSliceConverter,
15551602 SparseTensorLoadConverter, SparseExpandConverter,
15561603 SparseCompressConverter, SparseInsertConverter,
1604+ SparseSortCOOConverter,
15571605 SparseSliceGetterOpConverter<ToSliceOffsetOp,
15581606 StorageSpecifierKind::DimOffset>,
15591607 SparseSliceGetterOpConverter<ToSliceStrideOp,
0 commit comments