From 66fb6bb37e32eac9838e6913bc5f8f9ffd089f8c Mon Sep 17 00:00:00 2001
From: Matthias Springer
Date: Sat, 16 Nov 2024 05:11:48 +0100
Subject: [PATCH] [mlir][SparseTensor][NFC] Pass tensor type to descriptor
 helper

---
 .../Transforms/SparseTensorCodegen.cpp        | 58 ++++++++++++-------
 .../Transforms/Utils/CodegenUtils.cpp         |  5 --
 .../Transforms/Utils/CodegenUtils.h           |  3 -
 .../Transforms/Utils/SparseTensorDescriptor.h | 12 ++--
 4 files changed, 44 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index bf7b3f9bec558..25fca49cb0154 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -646,10 +646,11 @@ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
   matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     std::optional<Level> lvl = op.getConstantLvlIndex();
-    if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType()))
+    RankedTensorType srcType = op.getSource().getType();
+    if (!lvl || !getSparseTensorEncoding(srcType))
       return failure();
 
-    auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
     auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
 
     rewriter.replaceOp(op, sz);
@@ -675,8 +676,9 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
     assert(dstStt.hasSameDimToLvl(srcStt));
 
     // We don't need a mutable descriptor here as we perform sorting in-place.
-    auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo());
-    auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
+                                             op.getInputCoo().getType());
+    auto nnz = desc.getValMemSize(rewriter, op.getLoc());
     auto crd = desc.getAOSMemRef();
     auto val = desc.getValMemRef();
 
@@ -704,7 +706,8 @@ class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Simply lowers to specifer.get operation.
-    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
+                                             op.getSlice().getType());
     auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
                                     op.getDim().getZExtValue());
 
@@ -762,7 +765,8 @@ class SparseTensorAllocConverter
     Location loc = op.getLoc();
     // Deal with copy.
     if (op.getCopy()) {
-      auto desc = getDescriptorFromTensorTuple(adaptor.getCopy());
+      auto desc = getDescriptorFromTensorTuple(
+          adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
       SmallVector<Value> fields;
       fields.reserve(desc.getNumFields());
       // Memcpy on memref fields.
@@ -868,7 +872,9 @@ class SparseTensorDeallocConverter
     if (createDeallocs) {
       // Replace the sparse tensor deallocation with field deallocations.
       Location loc = op.getLoc();
-      auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+      auto desc = getDescriptorFromTensorTuple(
+          adaptor.getTensor(),
+          cast<RankedTensorType>(op.getTensor().getType()));
       for (auto input : desc.getMemRefFields())
         // Deallocate every buffer used to store the sparse tensor handler.
         rewriter.create<memref::DeallocOp>(loc, input);
@@ -889,7 +895,8 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
   matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Prepare descriptor.
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     // Generate optional insertion finalization code.
     if (op.getHasInserts())
       genEndInsert(rewriter, op.getLoc(), desc);
@@ -909,7 +916,8 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
     if (!getSparseTensorEncoding(op.getTensor().getType()))
       return failure();
     Location loc = op->getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     const auto srcType = getSparseTensorType(op.getTensor());
     Type eltType = srcType.getElementType();
     Type boolType = rewriter.getIntegerType(1);
@@ -959,7 +967,8 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
+                                                op.getTensor().getType());
     Value values = adaptor.getValues();
     Value filled = adaptor.getFilled();
     Value added = adaptor.getAdded();
@@ -1032,7 +1041,8 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
     assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
     Location loc = op.getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getDest());
+    auto desc =
+        getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
     TypeRange flatSpTensorTps = desc.getFields().getTypes();
     SmallVector<Value> params = llvm::to_vector(desc.getFields());
     params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
@@ -1059,7 +1069,8 @@ class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
     Level lvl = op.getLevel();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     auto mem = desc.getPosMemRef(lvl);
     auto size = desc.getPosMemSize(rewriter, loc, lvl);
     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1081,7 +1092,8 @@ class SparseToCoordinatesConverter
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
     Level lvl = op.getLevel();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
     if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
       auto size = desc.getCrdMemSize(rewriter, loc, lvl);
@@ -1106,7 +1118,8 @@ class SparseToCoordinatesBufferConverter
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
     Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     auto mem = desc.getAOSMemRef();
     auto size = desc.getCrdMemSize(rewriter, loc, lvl);
     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1126,7 +1139,8 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
     // The view is restricted to the actual size to ensure clients
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     auto mem = desc.getValMemRef();
     auto size = desc.getValMemSize(rewriter, loc);
     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1172,7 +1186,8 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
     // else:
     //    dst = memref.copy(src)
     Location loc = op.getLoc();
-    auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
+    auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
+                                                op.getSource().getType());
     SmallVector<Value> fields;
     foreachFieldAndTypeInSparseTensor(
         SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
@@ -1236,7 +1251,8 @@ class SparseExtractSliceConverter
     assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
 
     SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);
+    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
+                                                op.getSource().getType());
 
     auto newSpec = rewriter.create<StorageSpecifierInitOp>(
         loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
@@ -1285,8 +1301,9 @@ class SparseNumberOfEntriesConverter
     // Query memSizes for the actually stored values.
     // FIXME: the nse value computed in this way might be wrong when there is
     // any "loose_compressed" level.
-    rewriter.replaceOp(
-        op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor()));
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
+    rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
     return success();
   }
 };
@@ -1415,7 +1432,8 @@ struct SparseDisassembleOpConverter
   LogicalResult
   matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     Location loc = op.getLoc();
     SmallVector<Value> retMem;
     SmallVector<Value> retLen;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index de553a5f9bf08..f92382472b478 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -554,11 +554,6 @@ sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
       .getResult();
 }
 
-Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
-                                   Value tensor) {
-  return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
-}
-
 Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
                                                Value tensor, Dimension dim) {
   auto enc = getSparseTensorEncoding(tensor.getType());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index d0ef8a6860bb2..dc017e6baa6dc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -270,9 +270,6 @@ void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
 TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
                                        Value tensor);
 
-/// Generates code to retrieve the values size for the sparse tensor.
-Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);
-
 /// Generates code to retrieve the slice offset for the sparse tensor slice,
 /// return a constant if the offset is statically known.
 Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
index c2f631605bf4b..89858546e37e1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
@@ -245,18 +245,18 @@ inline Value genTuple(OpBuilder &builder, Location loc,
   return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields());
 }
 
-inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
+inline SparseTensorDescriptor
+getDescriptorFromTensorTuple(Value tensor, RankedTensorType type) {
   auto tuple = getTuple(tensor);
-  SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
-  return SparseTensorDescriptor(stt, tuple.getInputs());
+  return SparseTensorDescriptor(SparseTensorType(type), tuple.getInputs());
 }
 
 inline MutSparseTensorDescriptor
-getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
+getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields,
+                                RankedTensorType type) {
   auto tuple = getTuple(tensor);
   fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
-  SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
-  return MutSparseTensorDescriptor(stt, fields);
+  return MutSparseTensorDescriptor(SparseTensorType(type), fields);
 }
 
 } // namespace sparse_tensor
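For context only (not part of the patch above): a minimal sketch of how a
conversion pattern call site changes under the new helper signature, since the
descriptor no longer derives the sparse tensor type from the converted tuple
but takes the original RankedTensorType explicitly. FooOp, its getTensor()
accessor, and the pattern itself are hypothetical; only
getDescriptorFromTensorTuple and getValMemSize come from the patch.

  // Hypothetical conversion pattern; FooOp and getTensor() are illustrative.
  LogicalResult
  matchAndRewrite(FooOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Before this patch the descriptor recovered the sparse tensor type from
    // the converted tuple:
    //   auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
    // After this patch the original (pre-conversion) tensor type is passed in:
    auto desc = getDescriptorFromTensorTuple(
        adaptor.getTensor(), cast<RankedTensorType>(op.getTensor().getType()));
    // Descriptor queries are unchanged, e.g. the stored-values memory size,
    // which replaces the removed genValMemSize() helper:
    Value nnz = desc.getValMemSize(rewriter, op.getLoc());
    rewriter.replaceOp(op, nnz);
    return success();
  }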