From 0d667f53447230f054f738b47d9e42f31ec49467 Mon Sep 17 00:00:00 2001 From: Ryutaro Okada <1015ryu88@gmail.com> Date: Wed, 12 Nov 2025 20:23:25 +0900 Subject: [PATCH] [MLIR] Extend linalg.pack and linalg.unpack to accept memref Extend linalg.pack and linalg.unpack to accept memref operands in addition to tensors. As part of this change, we now disable all transformations when these ops have memref semantics. Closes https://github.com/llvm/llvm-project/issues/129004 Co-authored-by: Hyunsung Lee Signed-off-by: Ryutaro Okada <1015ryu88@gmail.com> --- .../Dialect/Linalg/IR/LinalgRelayoutOps.td | 99 ++-- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 483 ++++++++++++++++-- .../Linalg/Transforms/BlockPackMatmul.cpp | 4 + .../Transforms/DataLayoutPropagation.cpp | 69 ++- .../Transforms/PackAndUnpackPatterns.cpp | 50 +- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 16 + .../Dialect/Linalg/Transforms/Transforms.cpp | 38 +- .../Linalg/Transforms/Vectorization.cpp | 15 +- mlir/test/Dialect/Linalg/canonicalize.mlir | 59 +++ .../Dialect/Linalg/memref-pack-unpack.mlir | 47 ++ mlir/test/Dialect/Linalg/roundtrip.mlir | 26 + 11 files changed, 771 insertions(+), 135 deletions(-) create mode 100644 mlir/test/Dialect/Linalg/memref-pack-unpack.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td index 784bdd8e22f1f..ecf8b63ec3c71 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td @@ -7,11 +7,7 @@ //===----------------------------------------------------------------------===// // // This file defines Pack + Unpack Ops that have been moved from the Tensor -// dialect. As such, these are defined as memory-effect-free and only accept -// "tensors" as inputs. -// -// TODO: Once a good motivating example is identified, relax these -// restrictions. +// dialect. 
// //===----------------------------------------------------------------------===// @@ -30,24 +26,27 @@ include "mlir/IR/OpAsmInterface.td" // RelayoutOp //===----------------------------------------------------------------------===// -class Linalg_RelayoutOp traits = []> : - Op, - DestinationStyleOpInterface, LinalgRelayoutOpInterface, - ConditionallySpeculatable, NoMemoryEffect, - DeclareOpInterfaceMethods traits = []> + : Op, + DestinationStyleOpInterface, LinalgRelayoutOpInterface, + ConditionallySpeculatable, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods< + ReifyRankedShapedTypeOpInterface, [ "reifyResultShapes"]>, - TypesMatchWith<"result type matches type of dest", - "dest", "result", - "$_self">])> { + OptionalTypesMatchWith<"result type matches type of dest", + "dest", "result", "$_self">])> { code commonExtraClassDeclaration = [{ size_t getSourceRank() { return getSourceType().getRank(); }; size_t getDestRank() { return getDestType().getRank(); }; - RankedTensorType getSourceType() { - return ::llvm::cast(getSource().getType()); }; - RankedTensorType getDestType() { - return ::llvm::cast(getDest().getType()); }; + ShapedType getSourceType() { + return ::llvm::cast(getSource().getType()); }; + ShapedType getDestType() { + return ::llvm::cast(getDest().getType()); }; MutableOperandRange getDpsInitsMutable() { return getDestMutable(); } @@ -192,23 +191,12 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ // expect tensor<2x8xf32> because CeilDiv(9, 8) = 2 ``` }]; - let arguments = (ins AnyRankedTensor:$source, - AnyRankedTensor:$dest, - Optional:$padding_value, - DefaultValuedOptionalAttr:$outer_dims_perm, - DenseI64ArrayAttr:$inner_dims_pos, - Variadic:$inner_tiles, - DenseI64ArrayAttr:$static_inner_tiles); - let results = (outs AnyRankedTensor:$result); - let assemblyFormat = [{ - $source - (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)? - (`outer_dims_perm` `=` $outer_dims_perm^)? - `inner_dims_pos` `=` $inner_dims_pos - `inner_tiles` `=` - custom($inner_tiles, $static_inner_tiles) - `into` $dest attr-dict `:` type($source) `->` type($dest) - }]; + let arguments = (ins TensorOrMemRef<[AnyType]>:$source, + TensorOrMemRef<[AnyType]>:$dest, Optional:$padding_value, + DefaultValuedOptionalAttr:$outer_dims_perm, + DenseI64ArrayAttr:$inner_dims_pos, Variadic:$inner_tiles, + DenseI64ArrayAttr:$static_inner_tiles); + let results = (outs Optional:$result); let builders = [ OpBuilder<(ins "Value":$source, "Value":$dest, @@ -218,7 +206,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ CArg<"ArrayRef", "{}">:$outerDimsPerm)> ]; - let extraClassDeclaration = commonExtraClassDeclaration # [{ + let extraClassDeclaration = commonExtraClassDeclaration#[{ // Method to get the shape of the result as `SmallVector`. // This is a static method to allow getting the shape of the destination // expected while creating a `pack` op. @@ -230,7 +218,19 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ // Method to get the `RankedTensorType` of the result based on the inner // tiles, position of the inner tiles (innerDimsPos) and interchange vector // of outer loops (outerDimsPerm). 
- static RankedTensorType inferPackedType(RankedTensorType sourceType, + static RankedTensorType inferPackedTensorType(RankedTensorType sourceType, + ArrayRef innerTileSizes, ArrayRef innerDimsPos, + ArrayRef outerDimsPerm = {}); + + // Method to get the `MemRefType` of the result based on the inner + // tiles, position of the inner tiles (innerDimsPos) and interchange vector + // of outer loops (outerDimsPerm). + static MemRefType inferPackedMemRefType(MemRefType sourceType, + ArrayRef innerTileSizes, ArrayRef innerDimsPos, + ArrayRef outerDimsPerm = {}); + + // Returns the shape of the packed type. It is a shared helper helps type inference methods in a way that ensures that they agree on which dimensions are dynamic. + static SmallVector inferPackedShape(ArrayRef inputShape, ArrayRef innerTileSizes, ArrayRef innerDimsPos, ArrayRef outerDimsPerm = {}); @@ -282,6 +282,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ let hasCanonicalizeMethod = 1; let hasFolder = 1; + + let hasCustomAssemblyFormat = 1; } //===----------------------------------------------------------------------===// @@ -349,21 +351,12 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> { // Outer Dims: 9x3x8 Inner Dims: 4x2 ``` }]; - let arguments = (ins AnyRankedTensor:$source, - AnyRankedTensor:$dest, - DefaultValuedOptionalAttr:$outer_dims_perm, - DenseI64ArrayAttr:$inner_dims_pos, - Variadic:$inner_tiles, - DenseI64ArrayAttr:$static_inner_tiles); - let results = (outs AnyRankedTensor:$result); - let assemblyFormat = [{ - $source - (`outer_dims_perm` `=` $outer_dims_perm^)? - `inner_dims_pos` `=` $inner_dims_pos - `inner_tiles` `=` - custom($inner_tiles, $static_inner_tiles) - `into` $dest attr-dict `:` type($source) `->` type($dest) - }]; + let arguments = (ins TensorOrMemRef<[AnyType]>:$source, + TensorOrMemRef<[AnyType]>:$dest, + DefaultValuedOptionalAttr:$outer_dims_perm, + DenseI64ArrayAttr:$inner_dims_pos, Variadic:$inner_tiles, + DenseI64ArrayAttr:$static_inner_tiles); + let results = (outs Optional:$result); let builders = [ OpBuilder<(ins "Value":$source, "Value":$dest, @@ -406,6 +399,8 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> { let hasCanonicalizeMethod = 1; let hasFolder = 1; + + let hasCustomAssemblyFormat = 1; } #endif // LINALG_RELEAYOUT_OPS diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 3dc45edf4a23f..6bcf9281f4e52 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -4968,12 +4968,12 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() { template SmallVector getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) { - RankedTensorType packedType = (std::is_same::value) - ? packOrUnPack.getDestType() - : packOrUnPack.getSourceType(); - RankedTensorType unpackedType = (std::is_same::value) - ? packOrUnPack.getSourceType() - : packOrUnPack.getDestType(); + ShapedType packedType = (std::is_same::value) + ? packOrUnPack.getDestType() + : packOrUnPack.getSourceType(); + ShapedType unpackedType = (std::is_same::value) + ? packOrUnPack.getSourceType() + : packOrUnPack.getDestType(); SmallVector result( packedType.getShape().take_front(unpackedType.getRank())); if (!packOrUnPack.getOuterDimsPerm().empty()) { @@ -5107,15 +5107,34 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { return llvm::any_of(tiles, isZeroInteger); }; + // Verify that the source and destination are ranked types. 
+ if (!packOrUnPack.getSourceType().hasRank() || + !packOrUnPack.getDestType().hasRank()) { + return op->emitError("expected both source and destination to have rank"); + } + + // Verify that the Operation does not have mixed tensor/buffer semantics. + if (!packOrUnPack.hasPureBufferSemantics() && + !packOrUnPack.hasPureTensorSemantics()) { + return op->emitError("mixing tensor and buffer semantics is not allowed"); + } + const unsigned numResults = packOrUnPack.getNumResults(); + if (packOrUnPack.hasPureTensorSemantics() && numResults != 1) { + return op->emitError("expected 1 result, got ") << numResults; + } + if (packOrUnPack.hasPureBufferSemantics() && numResults != 0) { + return op->emitError("expected 0 results, got ") << numResults; + } + // Verify tiles. Do not allow zero tiles. SmallVector mixedTiles = packOrUnPack.getMixedTiles(); if (hasZeros(mixedTiles)) return op->emitError("invalid zero tile factor"); // Verify inner_dims_pos and outer_dims_perm. - RankedTensorType unpackedType = (std::is_same::value) - ? packOrUnPack.getSourceType() - : packOrUnPack.getDestType(); + ShapedType unpackedType = (std::is_same::value) + ? packOrUnPack.getSourceType() + : packOrUnPack.getDestType(); size_t unpackedRank = unpackedType.getRank(); ArrayRef innerDimsPos = packOrUnPack.getInnerDimsPos(); ArrayRef outerDimPerm = packOrUnPack.getOuterDimsPerm(); @@ -5152,8 +5171,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { // Verify result shape is greater than the minimum expected // by the pack operation, and that the output shape // represents full tiles. - RankedTensorType expectedPackedType = PackOp::inferPackedType( - unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm); + SmallVector expectedPackedShape = PackOp::inferPackedShape( + unpackedType.getShape(), packOrUnPack.getStaticTiles(), + packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm()); if (!llvm::all_of( llvm::zip(packedType.getShape().take_back(mixedTiles.size()), mixedTiles), @@ -5170,11 +5190,20 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { return op->emitError("mismatch in inner tile sizes specified and shaped of " "tiled dimension in the packed type"); } - if (failed(verifyCompatibleShape(expectedPackedType.getShape(), - packedType.getShape()))) { + if (failed( + verifyCompatibleShape(expectedPackedShape, packedType.getShape()))) { + auto elementType = unpackedType.getElementType(); + Type expectedType, actualType; + if (packOrUnPack.hasPureTensorSemantics()) { + expectedType = RankedTensorType::get(expectedPackedShape, elementType); + actualType = RankedTensorType::get(packedType.getShape(), elementType); + } else { + expectedType = MemRefType::get(expectedPackedShape, elementType); + actualType = MemRefType::get(packedType.getShape(), elementType); + } return op->emitError("expected ") - << expectedPackedType << " for the packed domain value, got " - << packedType; + << expectedType << " for the packed domain value, got " + << actualType; } return success(); } @@ -5235,9 +5264,158 @@ commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, //===----------------------------------------------------------------------===// void PackOp::getAsmResultNames(function_ref setNameFn) { + if (getNumResults() == 0) + return; setNameFn(getResult(), "pack"); } +ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand source, dest; + SmallVector dynamicTiles; + SmallVector paddingValue; + SmallVector 
paddingValueType; + SmallVector staticTiles; + DenseI64ArrayAttr innerDimsPos, outerDimsPerm; + Type sourceType, destType, resultType; + + if (parser.parseOperand(source)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("padding_value"))) { + if (parser.parseLParen() || + parser.parseOperandList(paddingValue, /*requiredOperandCount=*/1) || + parser.parseColon() || parser.parseTypeList(paddingValueType) || + parser.parseRParen()) + return failure(); + } + + if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) { + if (parser.parseEqual()) + return failure(); + + SmallVector outerDimsPermVec; + if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() { + int64_t value; + if (parser.parseInteger(value)) + return failure(); + outerDimsPermVec.push_back(value); + return success(); + })) + return failure(); + outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec); + } + + if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual()) + return failure(); + + SmallVector innerDimsPosVec; + if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() { + int64_t value; + if (parser.parseInteger(value)) + return failure(); + innerDimsPosVec.push_back(value); + return success(); + })) + return failure(); + innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec); + + if (parser.parseKeyword("inner_tiles") || parser.parseEqual()) + return failure(); + + DenseI64ArrayAttr staticTilesAttr; + if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr)) + return failure(); + for (auto val : staticTilesAttr.asArrayRef()) + staticTiles.push_back(val); + + if (parser.parseKeyword("into") || parser.parseOperand(dest)) + return failure(); + + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (parser.parseColon() || parser.parseType(sourceType)) + return failure(); + + bool hasArrow = succeeded(parser.parseOptionalArrow()); + if (hasArrow) { + if (parser.parseType(destType)) + return failure(); + } + + bool isMemRef = llvm::isa(sourceType); + if (!hasArrow) { + return parser.emitError(parser.getCurrentLocation(), + "pack/unpack requires '->' and destination type"); + } + + if (!isMemRef) { + resultType = destType; + } + + if (parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(dest, destType, result.operands)) + return failure(); + + if (!paddingValue.empty() && + parser.resolveOperands(paddingValue, paddingValueType[0], + result.operands)) + return failure(); + + if (!dynamicTiles.empty() && + parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(), + result.operands)) + return failure(); + + result.addAttribute("static_inner_tiles", + parser.getBuilder().getDenseI64ArrayAttr(staticTiles)); + result.addAttribute("inner_dims_pos", innerDimsPos); + if (outerDimsPerm) + result.addAttribute("outer_dims_perm", outerDimsPerm); + + SmallVector segmentSizes = { + 1, 1, static_cast(paddingValue.size()), + static_cast(dynamicTiles.size())}; + result.addAttribute("operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr(segmentSizes)); + + if (!isMemRef) + result.addTypes(resultType); + + return success(); +} + +void PackOp::print(OpAsmPrinter &p) { + p << " " << getSource(); + + if (getPaddingValue()) { + p << " padding_value(" << getPaddingValue() << " : " + << getPaddingValue().getType() << ")"; + } + + if (!getOuterDimsPerm().empty()) { + p << " outer_dims_perm = ["; + llvm::interleaveComma(getOuterDimsPerm(), p); + p << "]"; + } + + p 
<< " inner_dims_pos = ["; + llvm::interleaveComma(getInnerDimsPos(), p); + p << "]"; + + p << " inner_tiles = "; + printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr()); + + p << " into " << getDest(); + + p.printOptionalAttrDict((*this)->getAttrs(), + {"static_inner_tiles", "inner_dims_pos", + "outer_dims_perm", "operandSegmentSizes"}); + + p << " : " << getSource().getType(); + p << " -> " << getDest().getType(); +} + void PackOp::build(OpBuilder &builder, OperationState &state, Value source, Value dest, ArrayRef innerDimsPos, ArrayRef innerTiles, @@ -5260,6 +5438,8 @@ void PackOp::build(OpBuilder &builder, OperationState &state, Value source, LogicalResult PackOp::reifyResultShapes(OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + if (!hasPureTensorSemantics()) + return failure(); return reifyResultShapesImpl(*this, builder, reifiedReturnShapes); } @@ -5395,13 +5575,11 @@ asShapeWithAnyValueAsDynamic(ArrayRef ofrs) { return result; } -/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of -/// the packed type. Having a shared helper helps implement these two methods in -/// a way that ensures that they agree on which dimensions are dynamic. -static SmallVector getPackOpResultTypeShape( - ArrayRef sourceShape, ArrayRef innerTileSizes, - ArrayRef innerDimsPos, ArrayRef outerDimsPerm) { - SmallVector resultShape = llvm::to_vector(sourceShape); +SmallVector PackOp::inferPackedShape(ArrayRef inputShape, + ArrayRef innerTileSizes, + ArrayRef innerDimsPos, + ArrayRef outerDimsPerm) { + SmallVector resultShape = llvm::to_vector(inputShape); for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) { if (ShapedType::isDynamic(resultShape[tiledDim.value()])) continue; @@ -5441,9 +5619,9 @@ SmallVector PackOp::getResultShape( resultDims.append(innerTileSizes.begin(), innerTileSizes.end()); SmallVector resultTypeShape = - getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims), - asShapeWithAnyValueAsDynamic(innerTileSizes), - innerDimsPos, outerDimsPerm); + inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims), + asShapeWithAnyValueAsDynamic(innerTileSizes), + innerDimsPos, outerDimsPerm); // Fix-up `resultDims` to ensure that they are Value's if and only if the // result type shape says it's a dynamic dim. This is needed as callers may @@ -5459,15 +5637,21 @@ SmallVector PackOp::getResultShape( return resultDims; } -/// Get the expected packed type based on source type, tile factors, position of -/// the inner tiles and permutation of the outer tiled loop. 
-RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType, +RankedTensorType PackOp::inferPackedTensorType( + RankedTensorType sourceType, ArrayRef innerTileSizes, + ArrayRef innerDimsPos, ArrayRef outerDimsPerm) { + SmallVector resultShape = inferPackedShape( + sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm); + return RankedTensorType::get(resultShape, sourceType.getElementType()); +} + +MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType, ArrayRef innerTileSizes, ArrayRef innerDimsPos, ArrayRef outerDimsPerm) { - SmallVector resultShape = getPackOpResultTypeShape( + SmallVector resultShape = inferPackedShape( sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm); - return RankedTensorType::get(resultShape, sourceType.getElementType()); + return MemRefType::get(resultShape, sourceType.getElementType()); } Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source, @@ -5516,6 +5700,45 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc, getPaddingValue(), metadata.outerDimsPerm); } +template +static void getPackUnPackEffectsImpl( + OpTy op, SmallVectorImpl> + &effects) { + // No memory effects for pure tensor semantics + if (op.hasPureTensorSemantics()) + return; + + for (OpOperand &opOperand : op.getOperation()->getOpOperands()) { + if (!llvm::isa(opOperand.get().getType())) + continue; + + if (&opOperand == &op.getSourceMutable()) { + effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + } else if (&opOperand == &op.getDestMutable()) { + effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + } + } +} + +void PackOp::getEffects( + SmallVectorImpl> + &effects) { + getPackUnPackEffectsImpl(*this, effects); +} + +void UnPackOp::getEffects( + SmallVectorImpl> + &effects) { + getPackUnPackEffectsImpl(*this, effects); +} + /// Returns true if the tiles and the tiled dims are constant. template static bool areTilesAndTiledDimsAllConstant(OpTy op) { @@ -5535,6 +5758,8 @@ static bool areTilesAndTiledDimsAllConstant(OpTy op) { } Speculation::Speculatability PackOp::getSpeculatability() { + if (!hasPureTensorSemantics()) + return Speculation::NotSpeculatable; if (getPaddingValue()) return Speculation::Speculatable; @@ -5625,6 +5850,10 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl &srcShape, } LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + // Fold an pack(unpack(x)) to x. 
if (auto unPackOp = packOp.getSource().getDefiningOp()) { if (unPackOp.getSourceType() == packOp.getDestType() && @@ -5655,7 +5884,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource()); } Value dest = packOp.getDest(); - RankedTensorType originalResultType = packOp.getDestType(); + ShapedType originalResultType = packOp.getDestType(); bool needUpdateDestType = (destShape != originalResultType.getShape()); if (needUpdateDestType) { auto newDestType = packOp.getDestType().clone(destShape); @@ -5670,9 +5899,9 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { // Insert a cast if needed if (needUpdateDestType) { rewriter.setInsertionPointAfter(packOp); - auto castOp = - tensor::CastOp::create(rewriter, loc, originalResultType, packOp); - rewriter.replaceAllUsesExcept(packOp, castOp, castOp); + auto castOp = tensor::CastOp::create(rewriter, loc, originalResultType, + packOp.getResult()); + rewriter.replaceAllUsesExcept(packOp.getResult(), castOp, castOp); } return success(); } @@ -5681,8 +5910,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { } template -static bool isLikePadUnPad(PackOrUnpackOp packOp, - RankedTensorType packedTensorType) { +static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) { static_assert(std::is_same::value || std::is_same::value, "Function meant for pack/unpack"); @@ -5715,19 +5943,25 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp, bool PackOp::isLikePad() { auto packedTensorType = - llvm::cast((*this)->getResultTypes().front()); + llvm::cast((*this)->getResultTypes().front()); return isLikePadUnPad(*this, packedTensorType); } -OpFoldResult PackOp::fold(FoldAdaptor adaptor) { +::mlir::LogicalResult +PackOp::fold(FoldAdaptor adaptor, + ::llvm::SmallVectorImpl &results) { + if (!hasPureTensorSemantics()) + return failure(); std::optional paddingValue; if (auto pad = adaptor.getPaddingValue()) paddingValue = pad; if (OpFoldResult reshapedSource = reshapeConstantSource( llvm::dyn_cast_if_present(adaptor.getSource()), - getDestType(), paddingValue)) - return reshapedSource; - return {}; + cast(getDestType()), paddingValue)) { + results.push_back(reshapedSource); + return success(); + } + return failure(); } /// Folds a tensor.cast op into a consuming PackOp op if the @@ -5749,6 +5983,10 @@ struct FoldTensorCastPackOp : public OpRewritePattern { LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override { + // TODO: Support Memref PackOp. Temporarily return failure. 
+ if (!op.hasPureTensorSemantics()) + return failure(); + if (!tensor::hasFoldableTensorCastOperand(op)) return failure(); @@ -5791,12 +6029,143 @@ struct FoldTensorCastPackOp : public OpRewritePattern { void UnPackOp::getAsmResultNames( function_ref setNameFn) { + if (getNumResults() == 0) + return; setNameFn(getResult(), "unpack"); } +// Custom parser for UnPackOp that handles the memref/tensor case distinction +ParseResult UnPackOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand source, dest; + SmallVector dynamicTiles; + SmallVector staticTiles; + DenseI64ArrayAttr innerDimsPos, outerDimsPerm; + Type sourceType, destType, resultType; + + if (parser.parseOperand(source)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) { + if (parser.parseEqual()) + return failure(); + + SmallVector outerDimsPermVec; + if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() { + int64_t value; + if (parser.parseInteger(value)) + return failure(); + outerDimsPermVec.push_back(value); + return success(); + })) + return failure(); + outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec); + } + + if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual()) + return failure(); + + SmallVector innerDimsPosVec; + if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() { + int64_t value; + if (parser.parseInteger(value)) + return failure(); + innerDimsPosVec.push_back(value); + return success(); + })) + return failure(); + innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec); + + if (parser.parseKeyword("inner_tiles") || parser.parseEqual()) + return failure(); + + DenseI64ArrayAttr staticTilesAttr; + if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr)) + return failure(); + for (auto val : staticTilesAttr.asArrayRef()) + staticTiles.push_back(val); + + if (parser.parseKeyword("into") || parser.parseOperand(dest)) + return failure(); + + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (parser.parseColon() || parser.parseType(sourceType)) + return failure(); + + bool hasArrow = succeeded(parser.parseOptionalArrow()); + if (hasArrow) { + if (parser.parseType(destType)) + return failure(); + } + + bool isMemRef = llvm::isa(sourceType); + if (!hasArrow) { + return parser.emitError(parser.getCurrentLocation(), + "pack/unpack requires '->' and destination type"); + } + + if (!isMemRef) { + resultType = destType; + } + + if (parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(dest, destType, result.operands)) + return failure(); + + if (!dynamicTiles.empty() && + parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(), + result.operands)) + return failure(); + + result.addAttribute("static_inner_tiles", + parser.getBuilder().getDenseI64ArrayAttr(staticTiles)); + result.addAttribute("inner_dims_pos", innerDimsPos); + if (outerDimsPerm) + result.addAttribute("outer_dims_perm", outerDimsPerm); + + SmallVector segmentSizes = { + 1, 1, 0, static_cast(dynamicTiles.size())}; + result.addAttribute("operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr(segmentSizes)); + + if (!isMemRef) + result.addTypes(resultType); + + return success(); +} + +void UnPackOp::print(OpAsmPrinter &p) { + p << " " << getSource(); + + if (!getOuterDimsPerm().empty()) { + p << " outer_dims_perm = ["; + llvm::interleaveComma(getOuterDimsPerm(), p); + p << "]"; + } + + p << " inner_dims_pos = ["; + 
llvm::interleaveComma(getInnerDimsPos(), p); + p << "]"; + + p << " inner_tiles = "; + printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr()); + + p << " into " << getDest(); + + p.printOptionalAttrDict((*this)->getAttrs(), + {"static_inner_tiles", "inner_dims_pos", + "outer_dims_perm", "operandSegmentSizes"}); + + p << " : " << getSource().getType(); + p << " -> " << getDest().getType(); +} + LogicalResult UnPackOp::reifyResultShapes(OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + if (!hasPureTensorSemantics()) + return failure(); return reifyResultShapesImpl(*this, builder, reifiedReturnShapes); } @@ -5841,6 +6210,8 @@ LogicalResult UnPackOp::verify() { } Speculation::Speculatability UnPackOp::getSpeculatability() { + if (!hasPureTensorSemantics()) + return Speculation::NotSpeculatable; // See PackOp::getSpeculatability. if (!areTilesAndTiledDimsAllConstant(*this)) return Speculation::NotSpeculatable; @@ -5947,6 +6318,10 @@ static bool inferStaticShape(UnPackOp op, SmallVectorImpl &srcShape, LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp, PatternRewriter &rewriter) { + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unPackOp.hasPureTensorSemantics()) + return failure(); + /// unpack(pack(x)) -> x if (PackOp packOp = unPackOp.getSource().getDefiningOp()) { if (packOp.getSourceType() != unPackOp.getDestType()) @@ -6003,11 +6378,11 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp, dest = tensor::CastOp::create(rewriter, loc, newDestType, unPackOp.getDest()); } - Value newOp = UnPackOp::create( + UnPackOp newOp = UnPackOp::create( rewriter, loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm()); rewriter.replaceOpWithNewOp( - unPackOp, unPackOp.getResult().getType(), newOp); + unPackOp, unPackOp.getResult().getType(), newOp.getResult()); return success(); } @@ -6041,16 +6416,24 @@ bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) { } bool UnPackOp::isLikeUnPad() { - RankedTensorType packedTensorType = getSourceType(); + ShapedType packedTensorType = getSourceType(); return isLikePadUnPad(*this, packedTensorType); } -OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) { +::mlir::LogicalResult +UnPackOp::fold(FoldAdaptor adaptor, + ::llvm::SmallVectorImpl &results) { + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!hasPureTensorSemantics()) + return failure(); + if (OpFoldResult reshapedSource = reshapeConstantSource( llvm::dyn_cast_if_present(adaptor.getSource()), - getResult().getType())) - return reshapedSource; - return {}; + cast(getResult().getType()))) { + results.push_back(reshapedSource); + return success(); + } + return failure(); } /// Folds a tensor.cast op into a consuming UnPackOp op if the @@ -6072,6 +6455,10 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern { LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override { + // TODO: Support Memref UnPackOp. Temporarily return failure. 
+ if (!op.hasPureTensorSemantics()) + return failure(); + if (!tensor::hasFoldableTensorCastOperand(op)) return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp index 6912da3ffbc83..6ea1eb50b13ce 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp @@ -90,6 +90,10 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, linalg::PackOp packOp, AffineMap operandMap, ArrayRef blocksStartDimPos, bool transposeOuterBlocks, bool transposeInnerBlocks) { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + assert(operandMap.getNumDims() >= 4 && "expected at least 4D prepacked matmul"); assert(blocksStartDimPos.size() >= 2 && diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 3bb5f8af821c0..69d216fab7da8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -282,8 +282,8 @@ static bool getPackedOperandDetails( }); bool requirePadding = linalg::PackOp::requirePaddingValueStrict( inputType.getShape(), innerDimsPos, - linalg::PackOp::inferPackedType(inputType, maybeIntInnerTileSizes, - innerDimsPos, outerDimsPerm) + linalg::PackOp::inferPackedTensorType(inputType, maybeIntInnerTileSizes, + innerDimsPos, outerDimsPerm) .getShape(), outerDimsPerm, innerTileSizes); currOperandDetails.innerDimsPos = innerDimsPos; @@ -341,10 +341,11 @@ static std::tuple getOrCreatePackedViewOfOperand( b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm); auto poison = ub::PoisonOp::create( b, loc, getElementTypeOrSelf(opOperand->get().getType())); - Value packedOperand = + PackOp packedOperand = linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos, innerTileSizes, poison, outerDimsPerm); - return std::make_tuple(packedOperand, currOperandDetails.indexingMap); + return std::make_tuple(packedOperand.getResult(), + currOperandDetails.indexingMap); } /// This function is a helper subroutine to pack a genericOp and return it. It @@ -571,6 +572,10 @@ struct BubbleUpPackOpThroughGenericOpPattern LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + auto genericOp = bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn, poisonPaddingOk); if (failed(genericOp)) @@ -594,6 +599,10 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern { LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override { + // TODO: Support Memref PackOp. Temporarily return failure. 
+ if (!packOp.hasPureTensorSemantics()) + return failure(); + auto padOp = packOp.getSource().getDefiningOp(); if (!padOp) return failure(); @@ -653,19 +662,19 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern { lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); - auto newPadOp = - tensor::PadOp::create(rewriter, loc, /*result=*/Type(), sourcePack, - lowPad, highPad, paddingVal, padOp.getNofold()); + auto newPadOp = tensor::PadOp::create( + rewriter, loc, /*result=*/Type(), sourcePack.getResult(), lowPad, + highPad, paddingVal, padOp.getNofold()); // If the pad has more than one user, create an unpack on the new pad to // replace the other uses. if (!padOp->hasOneUse()) { auto unpackEmpty = linalg::UnPackOp::createDestinationTensor( rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm); - Value unpackedPad = + UnPackOp unpackedPad = linalg::UnPackOp::create(rewriter, loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm); - rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack); + rewriter.replaceAllUsesExcept(padOp, unpackedPad.getResult(), sourcePack); } // Replace the pack with the new pad. @@ -763,6 +772,10 @@ static LogicalResult bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp, linalg::PackOp packOp, PatternRewriter &rewriter) { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + SmallVector innerTileSizes = packOp.getStaticTiles(); ArrayRef innerDimsPos = packOp.getInnerDimsPos(); ArrayRef outerDimsPerm = packOp.getOuterDimsPerm(); @@ -812,8 +825,8 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp, } auto newCollapseOp = tensor::CollapseShapeOp::create( - rewriter, collapseOp.getLoc(), packOp.getType(), newPackOp, - newReassocIndices); + rewriter, collapseOp.getLoc(), packOp.getResult().getType(), + newPackOp.getResult(), newReassocIndices); rewriter.replaceOp(packOp, newCollapseOp); return success(); @@ -868,6 +881,10 @@ static LogicalResult bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp, linalg::PackOp packOp, PatternRewriter &rewriter) { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + // Outer dimensions permutation is not supported currently. // TODO: Handle outer_dims_perm variants. ArrayRef outerDimsPerm = packOp.getOuterDimsPerm(); @@ -918,7 +935,7 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp, // If reassociation is not possible, then reordering cannot happen. // This can be caused by pack padding affecting previously expanded // dimensions or packing extending dimensions. 
- RankedTensorType newPackType = linalg::PackOp::inferPackedType( + RankedTensorType newPackType = linalg::PackOp::inferPackedTensorType( expandOp.getSrcType(), packOp.getStaticInnerTiles(), projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector{}); auto reassocExpand = @@ -930,14 +947,14 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp, Value destTensor = linalg::PackOp::createDestinationTensor( rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(), projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector{}); - Value packedVal = linalg::PackOp::create( + PackOp packedVal = linalg::PackOp::create( rewriter, packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(), /*outerDimsPerm=*/SmallVector{}); - Value newExpandOp = tensor::ExpandShapeOp::create(rewriter, packOp.getLoc(), - packOp.getDestType(), - packedVal, *reassocExpand); + Value newExpandOp = tensor::ExpandShapeOp::create( + rewriter, packOp.getLoc(), packOp.getDestType(), packedVal.getResult(), + *reassocExpand); rewriter.replaceOp(packOp, newExpandOp); return success(); @@ -951,6 +968,10 @@ class BubbleUpPackOpThroughReshapeOp final LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + Operation *srcOp = packOp.getSource().getDefiningOp(); // Currently only support when the pack op is the only user. if (!srcOp || !(srcOp->getNumResults() == 1) || @@ -1001,6 +1022,10 @@ class BubbleUpPackOpThroughReshapeOp final static LogicalResult pushDownUnPackOpThroughExpandShape( linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp, PatternRewriter &rewriter, ControlPropagationFn controlFn) { + // TODO: Support Memref UnpackOp. Temporarily return failure. + if (!unPackOp.hasPureTensorSemantics()) + return failure(); + // User controlled propagation function. if (!controlFn(&expandOp.getSrcMutable())) return failure(); @@ -1048,7 +1073,7 @@ static LogicalResult pushDownUnPackOpThroughExpandShape( nextPos += 1; } - RankedTensorType newExpandType = linalg::PackOp::inferPackedType( + RankedTensorType newExpandType = linalg::PackOp::inferPackedTensorType( expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm); auto newExpandOp = tensor::ExpandShapeOp::create(rewriter, expandOp.getLoc(), newExpandType, @@ -1075,6 +1100,10 @@ class PushDownUnPackOpThroughReshapeOp final LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp, PatternRewriter &rewriter) const override { + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unPackOp.hasPureTensorSemantics()) + return failure(); + Value result = unPackOp.getResult(); // Currently only support unpack op with the single user. if (!result.hasOneUse()) { @@ -1274,6 +1303,10 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern { if (!unpackOp) return failure(); + // TODO: Support Memref UnPackOp. Temporarily return failure. 
+ if (!unpackOp.hasPureTensorSemantics()) + return failure(); + if (!controlFn(&padOp.getSourceMutable())) return failure(); @@ -1313,7 +1346,7 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern { tensor::EmptyOp::create(rewriter, loc, padOp.getResultType().getShape(), padOp.getResultType().getElementType()); - Value replacement = linalg::UnPackOp::create( + UnPackOp replacement = linalg::UnPackOp::create( rewriter, loc, newPadOp.getResult(), outputUnPack, innerDimsPos, unpackOp.getMixedTiles(), outerDimsPerm); rewriter.replaceOp(padOp, replacement); diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp index 1d4c11e418006..993eae62535c3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/PatternMatch.h" namespace mlir { @@ -110,8 +111,11 @@ struct SimplifyPackToExpandShape : public OpRewritePattern { PatternRewriter &rewriter) const override { if (packOp.getPaddingValue()) return rewriter.notifyMatchFailure(packOp, "expects no padding value"); + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); - RankedTensorType sourceType = packOp.getSourceType(); + ShapedType sourceType = packOp.getSourceType(); if (failed(isPackOnInnerMostDim(rewriter, packOp)) && failed(isPackOn1D(rewriter, packOp, sourceType.getShape(), packOp.getStaticTiles())) && @@ -119,7 +123,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern { return failure(); } - RankedTensorType destType = packOp.getDestType(); + ShapedType destType = packOp.getDestType(); auto reassociation = getReassociationIndicesForReshape(sourceType, destType); if (!reassociation) @@ -157,8 +161,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern { "expects outer_dims_perm is empty or an identity permutation"); } - RankedTensorType sourceType = unpackOp.getSourceType(); - RankedTensorType destType = unpackOp.getDestType(); + ShapedType sourceType = unpackOp.getSourceType(); + ShapedType destType = unpackOp.getDestType(); if (!sourceType.hasStaticShape() || !destType.hasStaticShape()) return rewriter.notifyMatchFailure(unpackOp, "expects static shapes"); @@ -173,7 +177,11 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern { LogicalResult matchAndRewrite(UnPackOp unpackOp, PatternRewriter &rewriter) const override { - RankedTensorType destType = unpackOp.getDestType(); + // TODO: Support Memref UnPackOp. Temporarily return failure. 
+ if (!unpackOp.hasPureTensorSemantics()) + return failure(); + + ShapedType destType = unpackOp.getDestType(); if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) && failed(isPackOn1D(rewriter, unpackOp, destType.getShape(), unpackOp.getStaticTiles())) && @@ -181,7 +189,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern { return failure(); } - RankedTensorType sourceType = unpackOp.getSourceType(); + ShapedType sourceType = unpackOp.getSourceType(); auto reassociation = getReassociationIndicesForReshape(sourceType, destType); if (!reassociation) @@ -225,7 +233,7 @@ struct FoldPadWithPackOp : public OpRewritePattern { // sizes - that is because it would be impossible to compute the padding // size and hence to establish whether "artificial" padding would be // created. - RankedTensorType unpackedType = packOp.getSourceType(); + ShapedType unpackedType = packOp.getSourceType(); SmallVector outerShapeWithoutTranspose = getPackedOuterShapeWithoutTransposition(packOp); for (auto [pos, tileSize, high] : @@ -274,6 +282,10 @@ struct FoldUnpackWithExtractSliceOp if (!unpackOp) return failure(); + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unpackOp.hasPureTensorSemantics()) + return failure(); + // User controlled folding function. if (controlFn && !controlFn(&sliceOp.getSourceMutable())) return failure(); @@ -336,6 +348,10 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp if (!packOp) return failure(); + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + // User controlled folding function. if (controlFn && !controlFn(&linalgOp->getOpOperand(0))) return failure(); @@ -395,6 +411,10 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp LogicalResult matchAndRewrite(PackOp packOp, PatternRewriter &rewriter) const override { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + auto linalgOp = packOp.getSource().getDefiningOp(); if (!linalgOp) return failure(); @@ -456,6 +476,10 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp if (!unPackOp) return failure(); + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unPackOp.hasPureTensorSemantics()) + return failure(); + // User controlled folding function. if (controlFn && !controlFn(&linalgOp->getOpOperand(0))) return failure(); @@ -504,6 +528,10 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp LogicalResult matchAndRewrite(UnPackOp unPackOp, PatternRewriter &rewriter) const override { + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unPackOp.hasPureTensorSemantics()) + return failure(); + auto linalgOp = unPackOp.getSource().getDefiningOp(); if (!linalgOp) return failure(); @@ -568,6 +596,10 @@ struct FoldEmptyTensorWithPackOp : public OpRewritePattern { LogicalResult matchAndRewrite(PackOp packOp, PatternRewriter &rewriter) const override { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + // Check for tensor.empty source. auto emptyOp = packOp.getSource().getDefiningOp(); if (!emptyOp) @@ -592,6 +624,10 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern { LogicalResult matchAndRewrite(UnPackOp unPackOp, PatternRewriter &rewriter) const override { + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unPackOp.hasPureTensorSemantics()) + return failure(); + // Check for tensor.empty source. 
auto emptyOp = unPackOp.getSource().getDefiningOp(); if (!emptyOp) diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 8a0440bcc6fb9..10e9fbd22a41d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -740,6 +740,10 @@ struct PackOpTiling ArrayRef offsets, ArrayRef sizes) const { auto packOp = cast(op); + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + Location loc = packOp.getLoc(); // The tiling is applied on interchanged dimensions. We have to undo the @@ -984,6 +988,10 @@ struct PackOpTiling ArrayRef sizes(allSizes[0]); auto packOp = cast(op); + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + Location loc = packOp.getLoc(); int64_t inputRank = packOp.getSourceRank(); @@ -1163,6 +1171,10 @@ struct UnPackOpTiling ArrayRef offsets, ArrayRef sizes) const { auto unpackOp = cast(op); + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unpackOp.hasPureTensorSemantics()) + return failure(); + int64_t srcRank = unpackOp.getSourceRank(); int64_t destRank = unpackOp.getDestRank(); int64_t numInnerTiles = srcRank - destRank; @@ -1333,6 +1345,10 @@ struct UnPackOpTiling return failure(); } auto unPackOp = cast(op); + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unPackOp.hasPureTensorSemantics()) + return failure(); + ArrayRef offsets(allOffsets[0]); ArrayRef sizes(allSizes[0]); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 027268cc20ddd..821b9d229ac42 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -26,6 +26,8 @@ #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" @@ -217,6 +219,10 @@ struct PackedOperandsDimList { FailureOr linalg::lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice) { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + // 1. Filter out NYI cases. auto packedTensorType = cast(packOp->getResultTypes().front()); @@ -344,11 +350,15 @@ FailureOr linalg::lowerPack(RewriterBase &rewriter, FailureOr linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice) { + // TODO: Support Memref UnPackOp. Temporarily return failure. 
+ if (!unPackOp.hasPureTensorSemantics()) + return failure(); + Location loc = unPackOp->getLoc(); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(unPackOp); - RankedTensorType packedTensorType = unPackOp.getSourceType(); + auto packedTensorType = cast(unPackOp.getSourceType()); int64_t packedRank = packedTensorType.getRank(); OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1); @@ -548,7 +558,7 @@ FailureOr linalg::pack(RewriterBase &rewriter, packOps.push_back(linalg::PackOp::create( rewriter, loc, operand, dest, innerPos, innerPackSizes, zero)); } - inputsAndInits.push_back(packOps.back()); + inputsAndInits.push_back(packOps.back().getResult()); } } @@ -575,7 +585,7 @@ FailureOr linalg::pack(RewriterBase &rewriter, unPackOps.push_back(linalg::UnPackOp::create( rewriter, packedLinalgOp->getLoc(), result, maybePackedInit.getSource(), maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles())); - results.push_back(unPackOps.back()); + results.push_back(unPackOps.back().getResult()); } // Step 5. Replace `linalgOp`. @@ -665,7 +675,7 @@ linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::PackOp transposedPackOp = packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm); - if (!packOp.getResult().hasOneUse()) + if (packOp.getNumResults() == 0 || !packOp.getResult().hasOneUse()) return rewriter.notifyMatchFailure(linalgOp, "expect single pack use"); OpOperand &packUse = *packOp->getUses().begin(); @@ -726,7 +736,10 @@ linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, } // Step 4. Finally, replace packOp now that we don't need it anymore. - rewriter.replaceOp(packOp, transposedPackOp->getResults()); + if (packOp.getNumResults() != 0) + rewriter.replaceOp(packOp, transposedPackOp->getResults()); + else + rewriter.eraseOp(packOp); return PackTransposeResult{transposedPackOp, transposedLinalgOp, transposedUnPackOp}; @@ -1018,6 +1031,10 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite( static Value getPackOpSourceOrPaddedSource(OpBuilder &builder, linalg::PackOp packOp) { Value input = packOp.getSource(); + // TODO: Support Memref PackOp. Temporarily return just Op Source. + if (!packOp.hasPureTensorSemantics()) + return input; + if (!packOp.getPaddingValue()) { return input; } @@ -1134,6 +1151,10 @@ getPackUnpackRankReducedPerm(ArrayRef shape, LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( linalg::PackOp packOp, PatternRewriter &rewriter) const { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + if (llvm::any_of(packOp.getTiledOuterDims(), [](int64_t dim) { return dim != 1; })) { return rewriter.notifyMatchFailure( @@ -1155,8 +1176,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( // Check whether this dim has been permuted. Permuting unit dims is fine // as that's effectively a no-op. 
- if (dim < prev && (packOp.getType().getShape()[prev] != 1 || - packOp.getType().getShape()[dim] != 1)) + if (dim < prev && (packOp.getResult().getType().getShape()[prev] != 1 || + packOp.getResult().getType().getShape()[dim] != 1)) return false; prev = dim; @@ -1279,6 +1300,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const { + if (!unpackOp.hasPureTensorSemantics()) + return failure(); + int64_t srcRank = unpackOp.getSourceRank(); int64_t destRank = unpackOp.getDestRank(); ArrayRef srcShape = unpackOp.getSourceType().getShape(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index dcf84c46949f3..44c799d97e8aa 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1877,8 +1877,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, SmallVector preTransposeWriteVecSizses(writeVectorSizes); auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata); applyPermutationToVector(preTransposeWriteVecSizses, destInvPermutation); - auto preTransposeWriteVecType = VectorType::get( - preTransposeWriteVecSizses, packOp.getType().getElementType()); + auto preTransposeWriteVecType = + VectorType::get(preTransposeWriteVecSizses, + packOp.getResult().getType().getElementType()); // Compute vector type for the _read_ opeartion. This is simply // pre-transpose-write-vector-type with the dimensions collapsed @@ -1954,7 +1955,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(unpackOp); - RankedTensorType unpackTensorType = unpackOp.getSourceType(); + ShapedType unpackTensorType = unpackOp.getSourceType(); ArrayRef sourceShape = unpackTensorType.getShape(); bool useInBoundsInsteadOfMasking = false; @@ -2117,6 +2118,10 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, static LogicalResult vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, ArrayRef inputVectorSizes) { + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unpackOp.hasPureTensorSemantics()) + return failure(); + // If there are no input vector sizes and all shapes are static, there is // nothing left to check. if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() && @@ -2454,6 +2459,10 @@ static LogicalResult vectorizeLinalgOpPrecondition( static LogicalResult vectorizePackOpPrecondition(linalg::PackOp packOp, ArrayRef inputVectorSizes) { + // TODO: Support Memref PackOp. Temporarily return failure. 
+ if (!packOp.hasPureTensorSemantics()) + return failure(); + auto padValue = packOp.getPaddingValue(); Attribute cstAttr; // TODO: Relax this condiiton diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index f4020ede4854e..1f5995a680f71 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -2058,3 +2058,62 @@ func.func @no_fold_extract_slice_into_unpack_non_zero_offset( // CHECK-SAME: into %[[DEST]] // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]] // CHECK: return %[[SLICE]] + +// ----- + +// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size( +// CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>, +// CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> { +// CHECK: %[[RES:.*]] = linalg.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {test_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32> +// CHECK: return %[[RES]] : tensor<7x?xi32> +func.func @fold_cast_unpack_dynamic_tile_size( + %src: tensor<1x1x8x1xi32>, + %res: tensor<7x?xi32>) -> tensor<7x?xi32> { + + %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32> + %c8 = arith.constant 8 : index + %unpack = linalg.unpack %cast + inner_dims_pos = [0, 1] + inner_tiles = [%c8, 1] + into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32> + return %unpack : tensor<7x?xi32> +} + +// ----- + +// CHECK-LABEL: func.func @fold_pack_unpack_tensor +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK: return %[[ARG0]] : tensor<2x3xf32> +func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> { + %unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = [] + into %x : tensor<2x3xf32> -> tensor<2x3xf32> + %packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = [] + into %x : tensor<2x3xf32> -> tensor<2x3xf32> + return %packed : tensor<2x3xf32> +} + +// ----- + +// Test that pack/unpack canonicalization is disabled for memref versions +// CHECK-LABEL: func.func @pack_unpack_memref_no_canonicalization +// CHECK: linalg.pack +// CHECK: linalg.unpack +// CHECK: return +func.func @pack_unpack_memref_no_canonicalization(%source: memref<128x256xf32>, %packed: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) { + linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %packed : memref<128x256xf32> -> memref<16x8x8x32xf32> + linalg.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32> + return +} + +// ----- + +// Test that unpack/pack canonicalization is disabled for memref versions +// CHECK-LABEL: func.func @unpack_pack_memref_no_canonicalization +// CHECK: linalg.unpack +// CHECK: linalg.pack +// CHECK: return +func.func @unpack_pack_memref_no_canonicalization(%packed: memref<16x8x8x32xf32>, %unpacked: memref<128x256xf32>, %dest: memref<16x8x8x32xf32>) { + linalg.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %unpacked : memref<16x8x8x32xf32> -> memref<128x256xf32> + linalg.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %dest : memref<128x256xf32> -> memref<16x8x8x32xf32> + return +} \ No newline at end of file diff --git a/mlir/test/Dialect/Linalg/memref-pack-unpack.mlir b/mlir/test/Dialect/Linalg/memref-pack-unpack.mlir new file mode 100644 index 0000000000000..2701a1e3512a2 --- /dev/null +++ b/mlir/test/Dialect/Linalg/memref-pack-unpack.mlir @@ -0,0 +1,47 @@ +// RUN: 
mlir-opt %s -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @test_pack_memref
+func.func @test_pack_memref(%arg0: memref<128x256xf32>, %arg1: memref<16x8x8x32xf32>) {
+  // CHECK-NOT: %{{.*}} = linalg.pack
+  // CHECK: linalg.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : memref<128x256xf32> -> memref<16x8x8x32xf32>
+  linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : memref<128x256xf32> -> memref<16x8x8x32xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @test_unpack_memref
+func.func @test_unpack_memref(%arg0: memref<16x8x8x32xf32>, %arg1: memref<128x256xf32>) {
+  // CHECK: linalg.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : memref<16x8x8x32xf32>
+  linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : memref<16x8x8x32xf32> -> memref<128x256xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @test_pack_memref_with_padding
+func.func @test_pack_memref_with_padding(%arg0: memref<127x255xf32>, %arg1: memref<16x8x8x32xf32>, %pad: f32) {
+  // CHECK: linalg.pack %{{.*}} padding_value(%{{.*}} : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : memref<127x255xf32>
+  linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : memref<127x255xf32> -> memref<16x8x8x32xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @test_pack_tensor
+func.func @test_pack_tensor(%arg0: tensor<128x256xf32>, %arg1: tensor<16x8x8x32xf32>) -> tensor<16x8x8x32xf32> {
+  // CHECK: %[[RESULT:.*]] = linalg.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
+  %0 = linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
+  // CHECK: return %[[RESULT]] : tensor<16x8x8x32xf32>
+  return %0 : tensor<16x8x8x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @test_unpack_tensor
+func.func @test_unpack_tensor(%arg0: tensor<16x8x8x32xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> {
+  // CHECK: %[[RESULT:.*]] = linalg.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
+  %0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
+  // CHECK: return %[[RESULT]] : tensor<128x256xf32>
+  return %0 : tensor<128x256xf32>
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 74928920c695a..71d0ffc417179 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -755,3 +755,29 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
 // CHECK-LABEL: func @conv2d_channel_first_q_promote(
 // CHECK: %[[arg0:[a-zA-z0-9]*]]: tensor<100x3x224x224xi8>, %[[arg1:[a-zA-z0-9]*]]: tensor<64x3x5x5xi8>, %[[arg2:[a-zA-z0-9]*]]: i8, %[[arg3:[a-zA-z0-9]*]]: i8)
 // CHECK: linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
+
+// -----
+
+func.func @pack_memref(%source: memref<128x256xf32>, %dest: memref<8x16x8x32xf32>) {
+  linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+      into %dest :
memref<128x256xf32> -> memref<8x16x8x32xf32>
+  return
+}
+
+// CHECK-LABEL: func @pack_memref(
+// CHECK: %[[source:[a-zA-z0-9]*]]: memref<128x256xf32>, %[[dest:[a-zA-z0-9]*]]: memref<8x16x8x32xf32>) {
+// CHECK: linalg.pack %[[source]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[dest]] : memref<128x256xf32> -> memref<8x16x8x32xf32>
+// CHECK: return
+// CHECK: }
+// -----
+
+func.func @unpack_memref(%source: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) {
+  linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+      into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
+  return
+}
+
+// CHECK-LABEL: func @unpack_memref(
+// CHECK: %[[source:[a-zA-z0-9]*]]: memref<16x8x8x32xf32>, %[[dest:[a-zA-z0-9]*]]: memref<128x256xf32>) {
+// CHECK: linalg.unpack %[[source]] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[dest]] : memref<16x8x8x32xf32> -> memref<128x256xf32>
+// CHECK: return
\ No newline at end of file