diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td index 784bdd8e22f1f..ecf8b63ec3c71 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td @@ -7,11 +7,7 @@ //===----------------------------------------------------------------------===// // // This file defines Pack + Unpack Ops that have been moved from the Tensor -// dialect. As such, these are defined as memory-effect-free and only accept -// "tensors" as inputs. -// -// TODO: Once a good motivating example is identified, relax these -// restrictions. +// dialect. // //===----------------------------------------------------------------------===// @@ -30,24 +26,27 @@ include "mlir/IR/OpAsmInterface.td" // RelayoutOp //===----------------------------------------------------------------------===// -class Linalg_RelayoutOp traits = []> : - Op, - DestinationStyleOpInterface, LinalgRelayoutOpInterface, - ConditionallySpeculatable, NoMemoryEffect, - DeclareOpInterfaceMethods traits = []> + : Op, + DestinationStyleOpInterface, LinalgRelayoutOpInterface, + ConditionallySpeculatable, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods< + ReifyRankedShapedTypeOpInterface, [ "reifyResultShapes"]>, - TypesMatchWith<"result type matches type of dest", - "dest", "result", - "$_self">])> { + OptionalTypesMatchWith<"result type matches type of dest", + "dest", "result", "$_self">])> { code commonExtraClassDeclaration = [{ size_t getSourceRank() { return getSourceType().getRank(); }; size_t getDestRank() { return getDestType().getRank(); }; - RankedTensorType getSourceType() { - return ::llvm::cast(getSource().getType()); }; - RankedTensorType getDestType() { - return ::llvm::cast(getDest().getType()); }; + ShapedType getSourceType() { + return ::llvm::cast(getSource().getType()); }; + ShapedType getDestType() { + return ::llvm::cast(getDest().getType()); }; MutableOperandRange getDpsInitsMutable() { return getDestMutable(); } @@ -192,23 +191,12 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ // expect tensor<2x8xf32> because CeilDiv(9, 8) = 2 ``` }]; - let arguments = (ins AnyRankedTensor:$source, - AnyRankedTensor:$dest, - Optional:$padding_value, - DefaultValuedOptionalAttr:$outer_dims_perm, - DenseI64ArrayAttr:$inner_dims_pos, - Variadic:$inner_tiles, - DenseI64ArrayAttr:$static_inner_tiles); - let results = (outs AnyRankedTensor:$result); - let assemblyFormat = [{ - $source - (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)? - (`outer_dims_perm` `=` $outer_dims_perm^)? - `inner_dims_pos` `=` $inner_dims_pos - `inner_tiles` `=` - custom($inner_tiles, $static_inner_tiles) - `into` $dest attr-dict `:` type($source) `->` type($dest) - }]; + let arguments = (ins TensorOrMemRef<[AnyType]>:$source, + TensorOrMemRef<[AnyType]>:$dest, Optional:$padding_value, + DefaultValuedOptionalAttr:$outer_dims_perm, + DenseI64ArrayAttr:$inner_dims_pos, Variadic:$inner_tiles, + DenseI64ArrayAttr:$static_inner_tiles); + let results = (outs Optional:$result); let builders = [ OpBuilder<(ins "Value":$source, "Value":$dest, @@ -218,7 +206,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ CArg<"ArrayRef", "{}">:$outerDimsPerm)> ]; - let extraClassDeclaration = commonExtraClassDeclaration # [{ + let extraClassDeclaration = commonExtraClassDeclaration#[{ // Method to get the shape of the result as `SmallVector`. 
// This is a static method to allow getting the shape of the destination // expected while creating a `pack` op. @@ -230,7 +218,19 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ // Method to get the `RankedTensorType` of the result based on the inner // tiles, position of the inner tiles (innerDimsPos) and interchange vector // of outer loops (outerDimsPerm). - static RankedTensorType inferPackedType(RankedTensorType sourceType, + static RankedTensorType inferPackedTensorType(RankedTensorType sourceType, + ArrayRef innerTileSizes, ArrayRef innerDimsPos, + ArrayRef outerDimsPerm = {}); + + // Method to get the `MemRefType` of the result based on the inner + // tiles, position of the inner tiles (innerDimsPos) and interchange vector + // of outer loops (outerDimsPerm). + static MemRefType inferPackedMemRefType(MemRefType sourceType, + ArrayRef innerTileSizes, ArrayRef innerDimsPos, + ArrayRef outerDimsPerm = {}); + + // Returns the shape of the packed type. It is a shared helper that ensures the type inference methods above agree on which dimensions are dynamic. + static SmallVector inferPackedShape(ArrayRef inputShape, ArrayRef innerTileSizes, ArrayRef innerDimsPos, ArrayRef outerDimsPerm = {}); @@ -282,6 +282,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ let hasCanonicalizeMethod = 1; let hasFolder = 1; + + let hasCustomAssemblyFormat = 1; } //===----------------------------------------------------------------------===// @@ -349,21 +351,12 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> { // Outer Dims: 9x3x8 Inner Dims: 4x2 ``` }]; - let arguments = (ins AnyRankedTensor:$source, - AnyRankedTensor:$dest, - DefaultValuedOptionalAttr:$outer_dims_perm, - DenseI64ArrayAttr:$inner_dims_pos, - Variadic:$inner_tiles, - DenseI64ArrayAttr:$static_inner_tiles); - let results = (outs AnyRankedTensor:$result); - let assemblyFormat = [{ - $source - (`outer_dims_perm` `=` $outer_dims_perm^)? - `inner_dims_pos` `=` $inner_dims_pos - `inner_tiles` `=` - custom($inner_tiles, $static_inner_tiles) - `into` $dest attr-dict `:` type($source) `->` type($dest) - }]; + let arguments = (ins TensorOrMemRef<[AnyType]>:$source, + TensorOrMemRef<[AnyType]>:$dest, + DefaultValuedOptionalAttr:$outer_dims_perm, + DenseI64ArrayAttr:$inner_dims_pos, Variadic:$inner_tiles, + DenseI64ArrayAttr:$static_inner_tiles); + let results = (outs Optional:$result); let builders = [ OpBuilder<(ins "Value":$source, "Value":$dest, @@ -406,6 +399,8 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> { let hasCanonicalizeMethod = 1; let hasFolder = 1; + + let hasCustomAssemblyFormat = 1; } #endif // LINALG_RELEAYOUT_OPS diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 3dc45edf4a23f..6bcf9281f4e52 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -4968,12 +4968,12 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() { template SmallVector getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) { - RankedTensorType packedType = (std::is_same::value) - ? packOrUnPack.getDestType() - : packOrUnPack.getSourceType(); - RankedTensorType unpackedType = (std::is_same::value) - ? packOrUnPack.getSourceType() - : packOrUnPack.getDestType(); + ShapedType packedType = (std::is_same::value) + ? packOrUnPack.getDestType() + : packOrUnPack.getSourceType(); + ShapedType unpackedType = (std::is_same::value) + ?
packOrUnPack.getSourceType() + : packOrUnPack.getDestType(); SmallVector result( packedType.getShape().take_front(unpackedType.getRank())); if (!packOrUnPack.getOuterDimsPerm().empty()) { @@ -5107,15 +5107,34 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { return llvm::any_of(tiles, isZeroInteger); }; + // Verify that the source and destination are ranked types. + if (!packOrUnPack.getSourceType().hasRank() || + !packOrUnPack.getDestType().hasRank()) { + return op->emitError("expected both source and destination to have rank"); + } + + // Verify that the Operation does not have mixed tensor/buffer semantics. + if (!packOrUnPack.hasPureBufferSemantics() && + !packOrUnPack.hasPureTensorSemantics()) { + return op->emitError("mixing tensor and buffer semantics is not allowed"); + } + const unsigned numResults = packOrUnPack.getNumResults(); + if (packOrUnPack.hasPureTensorSemantics() && numResults != 1) { + return op->emitError("expected 1 result, got ") << numResults; + } + if (packOrUnPack.hasPureBufferSemantics() && numResults != 0) { + return op->emitError("expected 0 results, got ") << numResults; + } + // Verify tiles. Do not allow zero tiles. SmallVector mixedTiles = packOrUnPack.getMixedTiles(); if (hasZeros(mixedTiles)) return op->emitError("invalid zero tile factor"); // Verify inner_dims_pos and outer_dims_perm. - RankedTensorType unpackedType = (std::is_same::value) - ? packOrUnPack.getSourceType() - : packOrUnPack.getDestType(); + ShapedType unpackedType = (std::is_same::value) + ? packOrUnPack.getSourceType() + : packOrUnPack.getDestType(); size_t unpackedRank = unpackedType.getRank(); ArrayRef innerDimsPos = packOrUnPack.getInnerDimsPos(); ArrayRef outerDimPerm = packOrUnPack.getOuterDimsPerm(); @@ -5152,8 +5171,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { // Verify result shape is greater than the minimum expected // by the pack operation, and that the output shape // represents full tiles. 
- RankedTensorType expectedPackedType = PackOp::inferPackedType( - unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm); + SmallVector expectedPackedShape = PackOp::inferPackedShape( + unpackedType.getShape(), packOrUnPack.getStaticTiles(), + packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm()); if (!llvm::all_of( llvm::zip(packedType.getShape().take_back(mixedTiles.size()), mixedTiles), @@ -5170,11 +5190,20 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { return op->emitError("mismatch in inner tile sizes specified and shaped of " "tiled dimension in the packed type"); } - if (failed(verifyCompatibleShape(expectedPackedType.getShape(), - packedType.getShape()))) { + if (failed( + verifyCompatibleShape(expectedPackedShape, packedType.getShape()))) { + auto elementType = unpackedType.getElementType(); + Type expectedType, actualType; + if (packOrUnPack.hasPureTensorSemantics()) { + expectedType = RankedTensorType::get(expectedPackedShape, elementType); + actualType = RankedTensorType::get(packedType.getShape(), elementType); + } else { + expectedType = MemRefType::get(expectedPackedShape, elementType); + actualType = MemRefType::get(packedType.getShape(), elementType); + } return op->emitError("expected ") - << expectedPackedType << " for the packed domain value, got " - << packedType; + << expectedType << " for the packed domain value, got " + << actualType; } return success(); } @@ -5235,9 +5264,158 @@ commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, //===----------------------------------------------------------------------===// void PackOp::getAsmResultNames(function_ref setNameFn) { + if (getNumResults() == 0) + return; setNameFn(getResult(), "pack"); } +ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand source, dest; + SmallVector dynamicTiles; + SmallVector paddingValue; + SmallVector paddingValueType; + SmallVector staticTiles; + DenseI64ArrayAttr innerDimsPos, outerDimsPerm; + Type sourceType, destType, resultType; + + if (parser.parseOperand(source)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("padding_value"))) { + if (parser.parseLParen() || + parser.parseOperandList(paddingValue, /*requiredOperandCount=*/1) || + parser.parseColon() || parser.parseTypeList(paddingValueType) || + parser.parseRParen()) + return failure(); + } + + if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) { + if (parser.parseEqual()) + return failure(); + + SmallVector outerDimsPermVec; + if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() { + int64_t value; + if (parser.parseInteger(value)) + return failure(); + outerDimsPermVec.push_back(value); + return success(); + })) + return failure(); + outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec); + } + + if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual()) + return failure(); + + SmallVector innerDimsPosVec; + if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() { + int64_t value; + if (parser.parseInteger(value)) + return failure(); + innerDimsPosVec.push_back(value); + return success(); + })) + return failure(); + innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec); + + if (parser.parseKeyword("inner_tiles") || parser.parseEqual()) + return failure(); + + DenseI64ArrayAttr staticTilesAttr; + if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr)) + return failure(); + for (auto val : 
staticTilesAttr.asArrayRef()) + staticTiles.push_back(val); + + if (parser.parseKeyword("into") || parser.parseOperand(dest)) + return failure(); + + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (parser.parseColon() || parser.parseType(sourceType)) + return failure(); + + bool hasArrow = succeeded(parser.parseOptionalArrow()); + if (hasArrow) { + if (parser.parseType(destType)) + return failure(); + } + + bool isMemRef = llvm::isa(sourceType); + if (!hasArrow) { + return parser.emitError(parser.getCurrentLocation(), + "pack/unpack requires '->' and destination type"); + } + + if (!isMemRef) { + resultType = destType; + } + + if (parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(dest, destType, result.operands)) + return failure(); + + if (!paddingValue.empty() && + parser.resolveOperands(paddingValue, paddingValueType[0], + result.operands)) + return failure(); + + if (!dynamicTiles.empty() && + parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(), + result.operands)) + return failure(); + + result.addAttribute("static_inner_tiles", + parser.getBuilder().getDenseI64ArrayAttr(staticTiles)); + result.addAttribute("inner_dims_pos", innerDimsPos); + if (outerDimsPerm) + result.addAttribute("outer_dims_perm", outerDimsPerm); + + SmallVector segmentSizes = { + 1, 1, static_cast(paddingValue.size()), + static_cast(dynamicTiles.size())}; + result.addAttribute("operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr(segmentSizes)); + + if (!isMemRef) + result.addTypes(resultType); + + return success(); +} + +void PackOp::print(OpAsmPrinter &p) { + p << " " << getSource(); + + if (getPaddingValue()) { + p << " padding_value(" << getPaddingValue() << " : " + << getPaddingValue().getType() << ")"; + } + + if (!getOuterDimsPerm().empty()) { + p << " outer_dims_perm = ["; + llvm::interleaveComma(getOuterDimsPerm(), p); + p << "]"; + } + + p << " inner_dims_pos = ["; + llvm::interleaveComma(getInnerDimsPos(), p); + p << "]"; + + p << " inner_tiles = "; + printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr()); + + p << " into " << getDest(); + + p.printOptionalAttrDict((*this)->getAttrs(), + {"static_inner_tiles", "inner_dims_pos", + "outer_dims_perm", "operandSegmentSizes"}); + + p << " : " << getSource().getType(); + p << " -> " << getDest().getType(); +} + void PackOp::build(OpBuilder &builder, OperationState &state, Value source, Value dest, ArrayRef innerDimsPos, ArrayRef innerTiles, @@ -5260,6 +5438,8 @@ void PackOp::build(OpBuilder &builder, OperationState &state, Value source, LogicalResult PackOp::reifyResultShapes(OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + if (!hasPureTensorSemantics()) + return failure(); return reifyResultShapesImpl(*this, builder, reifiedReturnShapes); } @@ -5395,13 +5575,11 @@ asShapeWithAnyValueAsDynamic(ArrayRef ofrs) { return result; } -/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of -/// the packed type. Having a shared helper helps implement these two methods in -/// a way that ensures that they agree on which dimensions are dynamic. 
-static SmallVector getPackOpResultTypeShape( - ArrayRef sourceShape, ArrayRef innerTileSizes, - ArrayRef innerDimsPos, ArrayRef outerDimsPerm) { - SmallVector resultShape = llvm::to_vector(sourceShape); +SmallVector PackOp::inferPackedShape(ArrayRef inputShape, + ArrayRef innerTileSizes, + ArrayRef innerDimsPos, + ArrayRef outerDimsPerm) { + SmallVector resultShape = llvm::to_vector(inputShape); for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) { if (ShapedType::isDynamic(resultShape[tiledDim.value()])) continue; @@ -5441,9 +5619,9 @@ SmallVector PackOp::getResultShape( resultDims.append(innerTileSizes.begin(), innerTileSizes.end()); SmallVector resultTypeShape = - getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims), - asShapeWithAnyValueAsDynamic(innerTileSizes), - innerDimsPos, outerDimsPerm); + inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims), + asShapeWithAnyValueAsDynamic(innerTileSizes), + innerDimsPos, outerDimsPerm); // Fix-up `resultDims` to ensure that they are Value's if and only if the // result type shape says it's a dynamic dim. This is needed as callers may @@ -5459,15 +5637,21 @@ SmallVector PackOp::getResultShape( return resultDims; } -/// Get the expected packed type based on source type, tile factors, position of -/// the inner tiles and permutation of the outer tiled loop. -RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType, +RankedTensorType PackOp::inferPackedTensorType( + RankedTensorType sourceType, ArrayRef innerTileSizes, + ArrayRef innerDimsPos, ArrayRef outerDimsPerm) { + SmallVector resultShape = inferPackedShape( + sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm); + return RankedTensorType::get(resultShape, sourceType.getElementType()); +} + +MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType, ArrayRef innerTileSizes, ArrayRef innerDimsPos, ArrayRef outerDimsPerm) { - SmallVector resultShape = getPackOpResultTypeShape( + SmallVector resultShape = inferPackedShape( sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm); - return RankedTensorType::get(resultShape, sourceType.getElementType()); + return MemRefType::get(resultShape, sourceType.getElementType()); } Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source, @@ -5516,6 +5700,45 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc, getPaddingValue(), metadata.outerDimsPerm); } +template +static void getPackUnPackEffectsImpl( + OpTy op, SmallVectorImpl> + &effects) { + // No memory effects for pure tensor semantics + if (op.hasPureTensorSemantics()) + return; + + for (OpOperand &opOperand : op.getOperation()->getOpOperands()) { + if (!llvm::isa(opOperand.get().getType())) + continue; + + if (&opOperand == &op.getSourceMutable()) { + effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + } else if (&opOperand == &op.getDestMutable()) { + effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + } + } +} + +void PackOp::getEffects( + SmallVectorImpl> + &effects) { + getPackUnPackEffectsImpl(*this, effects); +} + +void UnPackOp::getEffects( + SmallVectorImpl> + &effects) { + getPackUnPackEffectsImpl(*this, effects); +} + /// Returns true if the 
tiles and the tiled dims are constant. template static bool areTilesAndTiledDimsAllConstant(OpTy op) { @@ -5535,6 +5758,8 @@ static bool areTilesAndTiledDimsAllConstant(OpTy op) { } Speculation::Speculatability PackOp::getSpeculatability() { + if (!hasPureTensorSemantics()) + return Speculation::NotSpeculatable; if (getPaddingValue()) return Speculation::Speculatable; @@ -5625,6 +5850,10 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl &srcShape, } LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + // Fold an pack(unpack(x)) to x. if (auto unPackOp = packOp.getSource().getDefiningOp()) { if (unPackOp.getSourceType() == packOp.getDestType() && @@ -5655,7 +5884,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource()); } Value dest = packOp.getDest(); - RankedTensorType originalResultType = packOp.getDestType(); + ShapedType originalResultType = packOp.getDestType(); bool needUpdateDestType = (destShape != originalResultType.getShape()); if (needUpdateDestType) { auto newDestType = packOp.getDestType().clone(destShape); @@ -5670,9 +5899,9 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { // Insert a cast if needed if (needUpdateDestType) { rewriter.setInsertionPointAfter(packOp); - auto castOp = - tensor::CastOp::create(rewriter, loc, originalResultType, packOp); - rewriter.replaceAllUsesExcept(packOp, castOp, castOp); + auto castOp = tensor::CastOp::create(rewriter, loc, originalResultType, + packOp.getResult()); + rewriter.replaceAllUsesExcept(packOp.getResult(), castOp, castOp); } return success(); } @@ -5681,8 +5910,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { } template -static bool isLikePadUnPad(PackOrUnpackOp packOp, - RankedTensorType packedTensorType) { +static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) { static_assert(std::is_same::value || std::is_same::value, "Function meant for pack/unpack"); @@ -5715,19 +5943,25 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp, bool PackOp::isLikePad() { auto packedTensorType = - llvm::cast((*this)->getResultTypes().front()); + llvm::cast((*this)->getResultTypes().front()); return isLikePadUnPad(*this, packedTensorType); } -OpFoldResult PackOp::fold(FoldAdaptor adaptor) { +::mlir::LogicalResult +PackOp::fold(FoldAdaptor adaptor, + ::llvm::SmallVectorImpl &results) { + if (!hasPureTensorSemantics()) + return failure(); std::optional paddingValue; if (auto pad = adaptor.getPaddingValue()) paddingValue = pad; if (OpFoldResult reshapedSource = reshapeConstantSource( llvm::dyn_cast_if_present(adaptor.getSource()), - getDestType(), paddingValue)) - return reshapedSource; - return {}; + cast(getDestType()), paddingValue)) { + results.push_back(reshapedSource); + return success(); + } + return failure(); } /// Folds a tensor.cast op into a consuming PackOp op if the @@ -5749,6 +5983,10 @@ struct FoldTensorCastPackOp : public OpRewritePattern { LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override { + // TODO: Support Memref PackOp. Temporarily return failure. 
+ if (!op.hasPureTensorSemantics()) + return failure(); + if (!tensor::hasFoldableTensorCastOperand(op)) return failure(); @@ -5791,12 +6029,143 @@ struct FoldTensorCastPackOp : public OpRewritePattern { void UnPackOp::getAsmResultNames( function_ref setNameFn) { + if (getNumResults() == 0) + return; setNameFn(getResult(), "unpack"); } +// Custom parser for UnPackOp that handles the memref/tensor case distinction +ParseResult UnPackOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand source, dest; + SmallVector dynamicTiles; + SmallVector staticTiles; + DenseI64ArrayAttr innerDimsPos, outerDimsPerm; + Type sourceType, destType, resultType; + + if (parser.parseOperand(source)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) { + if (parser.parseEqual()) + return failure(); + + SmallVector outerDimsPermVec; + if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() { + int64_t value; + if (parser.parseInteger(value)) + return failure(); + outerDimsPermVec.push_back(value); + return success(); + })) + return failure(); + outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec); + } + + if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual()) + return failure(); + + SmallVector innerDimsPosVec; + if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() { + int64_t value; + if (parser.parseInteger(value)) + return failure(); + innerDimsPosVec.push_back(value); + return success(); + })) + return failure(); + innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec); + + if (parser.parseKeyword("inner_tiles") || parser.parseEqual()) + return failure(); + + DenseI64ArrayAttr staticTilesAttr; + if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr)) + return failure(); + for (auto val : staticTilesAttr.asArrayRef()) + staticTiles.push_back(val); + + if (parser.parseKeyword("into") || parser.parseOperand(dest)) + return failure(); + + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (parser.parseColon() || parser.parseType(sourceType)) + return failure(); + + bool hasArrow = succeeded(parser.parseOptionalArrow()); + if (hasArrow) { + if (parser.parseType(destType)) + return failure(); + } + + bool isMemRef = llvm::isa(sourceType); + if (!hasArrow) { + return parser.emitError(parser.getCurrentLocation(), + "pack/unpack requires '->' and destination type"); + } + + if (!isMemRef) { + resultType = destType; + } + + if (parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(dest, destType, result.operands)) + return failure(); + + if (!dynamicTiles.empty() && + parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(), + result.operands)) + return failure(); + + result.addAttribute("static_inner_tiles", + parser.getBuilder().getDenseI64ArrayAttr(staticTiles)); + result.addAttribute("inner_dims_pos", innerDimsPos); + if (outerDimsPerm) + result.addAttribute("outer_dims_perm", outerDimsPerm); + + SmallVector segmentSizes = { + 1, 1, 0, static_cast(dynamicTiles.size())}; + result.addAttribute("operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr(segmentSizes)); + + if (!isMemRef) + result.addTypes(resultType); + + return success(); +} + +void UnPackOp::print(OpAsmPrinter &p) { + p << " " << getSource(); + + if (!getOuterDimsPerm().empty()) { + p << " outer_dims_perm = ["; + llvm::interleaveComma(getOuterDimsPerm(), p); + p << "]"; + } + + p << " inner_dims_pos = ["; + 
llvm::interleaveComma(getInnerDimsPos(), p); + p << "]"; + + p << " inner_tiles = "; + printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr()); + + p << " into " << getDest(); + + p.printOptionalAttrDict((*this)->getAttrs(), + {"static_inner_tiles", "inner_dims_pos", + "outer_dims_perm", "operandSegmentSizes"}); + + p << " : " << getSource().getType(); + p << " -> " << getDest().getType(); +} + LogicalResult UnPackOp::reifyResultShapes(OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + if (!hasPureTensorSemantics()) + return failure(); return reifyResultShapesImpl(*this, builder, reifiedReturnShapes); } @@ -5841,6 +6210,8 @@ LogicalResult UnPackOp::verify() { } Speculation::Speculatability UnPackOp::getSpeculatability() { + if (!hasPureTensorSemantics()) + return Speculation::NotSpeculatable; // See PackOp::getSpeculatability. if (!areTilesAndTiledDimsAllConstant(*this)) return Speculation::NotSpeculatable; @@ -5947,6 +6318,10 @@ static bool inferStaticShape(UnPackOp op, SmallVectorImpl &srcShape, LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp, PatternRewriter &rewriter) { + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unPackOp.hasPureTensorSemantics()) + return failure(); + /// unpack(pack(x)) -> x if (PackOp packOp = unPackOp.getSource().getDefiningOp()) { if (packOp.getSourceType() != unPackOp.getDestType()) @@ -6003,11 +6378,11 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp, dest = tensor::CastOp::create(rewriter, loc, newDestType, unPackOp.getDest()); } - Value newOp = UnPackOp::create( + UnPackOp newOp = UnPackOp::create( rewriter, loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm()); rewriter.replaceOpWithNewOp( - unPackOp, unPackOp.getResult().getType(), newOp); + unPackOp, unPackOp.getResult().getType(), newOp.getResult()); return success(); } @@ -6041,16 +6416,24 @@ bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) { } bool UnPackOp::isLikeUnPad() { - RankedTensorType packedTensorType = getSourceType(); + ShapedType packedTensorType = getSourceType(); return isLikePadUnPad(*this, packedTensorType); } -OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) { +::mlir::LogicalResult +UnPackOp::fold(FoldAdaptor adaptor, + ::llvm::SmallVectorImpl &results) { + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!hasPureTensorSemantics()) + return failure(); + if (OpFoldResult reshapedSource = reshapeConstantSource( llvm::dyn_cast_if_present(adaptor.getSource()), - getResult().getType())) - return reshapedSource; - return {}; + cast(getResult().getType()))) { + results.push_back(reshapedSource); + return success(); + } + return failure(); } /// Folds a tensor.cast op into a consuming UnPackOp op if the @@ -6072,6 +6455,10 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern { LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override { + // TODO: Support Memref UnPackOp. Temporarily return failure. 
+ if (!op.hasPureTensorSemantics()) + return failure(); + if (!tensor::hasFoldableTensorCastOperand(op)) return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp index 6912da3ffbc83..6ea1eb50b13ce 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp @@ -90,6 +90,10 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, linalg::PackOp packOp, AffineMap operandMap, ArrayRef blocksStartDimPos, bool transposeOuterBlocks, bool transposeInnerBlocks) { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + assert(operandMap.getNumDims() >= 4 && "expected at least 4D prepacked matmul"); assert(blocksStartDimPos.size() >= 2 && diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 3bb5f8af821c0..69d216fab7da8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -282,8 +282,8 @@ static bool getPackedOperandDetails( }); bool requirePadding = linalg::PackOp::requirePaddingValueStrict( inputType.getShape(), innerDimsPos, - linalg::PackOp::inferPackedType(inputType, maybeIntInnerTileSizes, - innerDimsPos, outerDimsPerm) + linalg::PackOp::inferPackedTensorType(inputType, maybeIntInnerTileSizes, + innerDimsPos, outerDimsPerm) .getShape(), outerDimsPerm, innerTileSizes); currOperandDetails.innerDimsPos = innerDimsPos; @@ -341,10 +341,11 @@ static std::tuple getOrCreatePackedViewOfOperand( b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm); auto poison = ub::PoisonOp::create( b, loc, getElementTypeOrSelf(opOperand->get().getType())); - Value packedOperand = + PackOp packedOperand = linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos, innerTileSizes, poison, outerDimsPerm); - return std::make_tuple(packedOperand, currOperandDetails.indexingMap); + return std::make_tuple(packedOperand.getResult(), + currOperandDetails.indexingMap); } /// This function is a helper subroutine to pack a genericOp and return it. It @@ -571,6 +572,10 @@ struct BubbleUpPackOpThroughGenericOpPattern LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + auto genericOp = bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn, poisonPaddingOk); if (failed(genericOp)) @@ -594,6 +599,10 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern { LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override { + // TODO: Support Memref PackOp. Temporarily return failure. 
+ if (!packOp.hasPureTensorSemantics()) + return failure(); + auto padOp = packOp.getSource().getDefiningOp(); if (!padOp) return failure(); @@ -653,19 +662,19 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern { lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); - auto newPadOp = - tensor::PadOp::create(rewriter, loc, /*result=*/Type(), sourcePack, - lowPad, highPad, paddingVal, padOp.getNofold()); + auto newPadOp = tensor::PadOp::create( + rewriter, loc, /*result=*/Type(), sourcePack.getResult(), lowPad, + highPad, paddingVal, padOp.getNofold()); // If the pad has more than one user, create an unpack on the new pad to // replace the other uses. if (!padOp->hasOneUse()) { auto unpackEmpty = linalg::UnPackOp::createDestinationTensor( rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm); - Value unpackedPad = + UnPackOp unpackedPad = linalg::UnPackOp::create(rewriter, loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm); - rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack); + rewriter.replaceAllUsesExcept(padOp, unpackedPad.getResult(), sourcePack); } // Replace the pack with the new pad. @@ -763,6 +772,10 @@ static LogicalResult bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp, linalg::PackOp packOp, PatternRewriter &rewriter) { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + SmallVector innerTileSizes = packOp.getStaticTiles(); ArrayRef innerDimsPos = packOp.getInnerDimsPos(); ArrayRef outerDimsPerm = packOp.getOuterDimsPerm(); @@ -812,8 +825,8 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp, } auto newCollapseOp = tensor::CollapseShapeOp::create( - rewriter, collapseOp.getLoc(), packOp.getType(), newPackOp, - newReassocIndices); + rewriter, collapseOp.getLoc(), packOp.getResult().getType(), + newPackOp.getResult(), newReassocIndices); rewriter.replaceOp(packOp, newCollapseOp); return success(); @@ -868,6 +881,10 @@ static LogicalResult bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp, linalg::PackOp packOp, PatternRewriter &rewriter) { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + // Outer dimensions permutation is not supported currently. // TODO: Handle outer_dims_perm variants. ArrayRef outerDimsPerm = packOp.getOuterDimsPerm(); @@ -918,7 +935,7 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp, // If reassociation is not possible, then reordering cannot happen. // This can be caused by pack padding affecting previously expanded // dimensions or packing extending dimensions. 
- RankedTensorType newPackType = linalg::PackOp::inferPackedType( + RankedTensorType newPackType = linalg::PackOp::inferPackedTensorType( expandOp.getSrcType(), packOp.getStaticInnerTiles(), projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector{}); auto reassocExpand = @@ -930,14 +947,14 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp, Value destTensor = linalg::PackOp::createDestinationTensor( rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(), projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector{}); - Value packedVal = linalg::PackOp::create( + PackOp packedVal = linalg::PackOp::create( rewriter, packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(), /*outerDimsPerm=*/SmallVector{}); - Value newExpandOp = tensor::ExpandShapeOp::create(rewriter, packOp.getLoc(), - packOp.getDestType(), - packedVal, *reassocExpand); + Value newExpandOp = tensor::ExpandShapeOp::create( + rewriter, packOp.getLoc(), packOp.getDestType(), packedVal.getResult(), + *reassocExpand); rewriter.replaceOp(packOp, newExpandOp); return success(); @@ -951,6 +968,10 @@ class BubbleUpPackOpThroughReshapeOp final LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + Operation *srcOp = packOp.getSource().getDefiningOp(); // Currently only support when the pack op is the only user. if (!srcOp || !(srcOp->getNumResults() == 1) || @@ -1001,6 +1022,10 @@ class BubbleUpPackOpThroughReshapeOp final static LogicalResult pushDownUnPackOpThroughExpandShape( linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp, PatternRewriter &rewriter, ControlPropagationFn controlFn) { + // TODO: Support Memref UnpackOp. Temporarily return failure. + if (!unPackOp.hasPureTensorSemantics()) + return failure(); + // User controlled propagation function. if (!controlFn(&expandOp.getSrcMutable())) return failure(); @@ -1048,7 +1073,7 @@ static LogicalResult pushDownUnPackOpThroughExpandShape( nextPos += 1; } - RankedTensorType newExpandType = linalg::PackOp::inferPackedType( + RankedTensorType newExpandType = linalg::PackOp::inferPackedTensorType( expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm); auto newExpandOp = tensor::ExpandShapeOp::create(rewriter, expandOp.getLoc(), newExpandType, @@ -1075,6 +1100,10 @@ class PushDownUnPackOpThroughReshapeOp final LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp, PatternRewriter &rewriter) const override { + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unPackOp.hasPureTensorSemantics()) + return failure(); + Value result = unPackOp.getResult(); // Currently only support unpack op with the single user. if (!result.hasOneUse()) { @@ -1274,6 +1303,10 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern { if (!unpackOp) return failure(); + // TODO: Support Memref UnPackOp. Temporarily return failure. 
+ if (!unpackOp.hasPureTensorSemantics()) + return failure(); + if (!controlFn(&padOp.getSourceMutable())) return failure(); @@ -1313,7 +1346,7 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern { tensor::EmptyOp::create(rewriter, loc, padOp.getResultType().getShape(), padOp.getResultType().getElementType()); - Value replacement = linalg::UnPackOp::create( + UnPackOp replacement = linalg::UnPackOp::create( rewriter, loc, newPadOp.getResult(), outputUnPack, innerDimsPos, unpackOp.getMixedTiles(), outerDimsPerm); rewriter.replaceOp(padOp, replacement); diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp index 1d4c11e418006..993eae62535c3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/PatternMatch.h" namespace mlir { @@ -110,8 +111,11 @@ struct SimplifyPackToExpandShape : public OpRewritePattern { PatternRewriter &rewriter) const override { if (packOp.getPaddingValue()) return rewriter.notifyMatchFailure(packOp, "expects no padding value"); + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); - RankedTensorType sourceType = packOp.getSourceType(); + ShapedType sourceType = packOp.getSourceType(); if (failed(isPackOnInnerMostDim(rewriter, packOp)) && failed(isPackOn1D(rewriter, packOp, sourceType.getShape(), packOp.getStaticTiles())) && @@ -119,7 +123,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern { return failure(); } - RankedTensorType destType = packOp.getDestType(); + ShapedType destType = packOp.getDestType(); auto reassociation = getReassociationIndicesForReshape(sourceType, destType); if (!reassociation) @@ -157,8 +161,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern { "expects outer_dims_perm is empty or an identity permutation"); } - RankedTensorType sourceType = unpackOp.getSourceType(); - RankedTensorType destType = unpackOp.getDestType(); + ShapedType sourceType = unpackOp.getSourceType(); + ShapedType destType = unpackOp.getDestType(); if (!sourceType.hasStaticShape() || !destType.hasStaticShape()) return rewriter.notifyMatchFailure(unpackOp, "expects static shapes"); @@ -173,7 +177,11 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern { LogicalResult matchAndRewrite(UnPackOp unpackOp, PatternRewriter &rewriter) const override { - RankedTensorType destType = unpackOp.getDestType(); + // TODO: Support Memref UnPackOp. Temporarily return failure. 
+ if (!unpackOp.hasPureTensorSemantics()) + return failure(); + + ShapedType destType = unpackOp.getDestType(); if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) && failed(isPackOn1D(rewriter, unpackOp, destType.getShape(), unpackOp.getStaticTiles())) && @@ -181,7 +189,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern { return failure(); } - RankedTensorType sourceType = unpackOp.getSourceType(); + ShapedType sourceType = unpackOp.getSourceType(); auto reassociation = getReassociationIndicesForReshape(sourceType, destType); if (!reassociation) @@ -225,7 +233,7 @@ struct FoldPadWithPackOp : public OpRewritePattern { // sizes - that is because it would be impossible to compute the padding // size and hence to establish whether "artificial" padding would be // created. - RankedTensorType unpackedType = packOp.getSourceType(); + ShapedType unpackedType = packOp.getSourceType(); SmallVector outerShapeWithoutTranspose = getPackedOuterShapeWithoutTransposition(packOp); for (auto [pos, tileSize, high] : @@ -274,6 +282,10 @@ struct FoldUnpackWithExtractSliceOp if (!unpackOp) return failure(); + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unpackOp.hasPureTensorSemantics()) + return failure(); + // User controlled folding function. if (controlFn && !controlFn(&sliceOp.getSourceMutable())) return failure(); @@ -336,6 +348,10 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp if (!packOp) return failure(); + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + // User controlled folding function. if (controlFn && !controlFn(&linalgOp->getOpOperand(0))) return failure(); @@ -395,6 +411,10 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp LogicalResult matchAndRewrite(PackOp packOp, PatternRewriter &rewriter) const override { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + auto linalgOp = packOp.getSource().getDefiningOp(); if (!linalgOp) return failure(); @@ -456,6 +476,10 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp if (!unPackOp) return failure(); + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unPackOp.hasPureTensorSemantics()) + return failure(); + // User controlled folding function. if (controlFn && !controlFn(&linalgOp->getOpOperand(0))) return failure(); @@ -504,6 +528,10 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp LogicalResult matchAndRewrite(UnPackOp unPackOp, PatternRewriter &rewriter) const override { + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unPackOp.hasPureTensorSemantics()) + return failure(); + auto linalgOp = unPackOp.getSource().getDefiningOp(); if (!linalgOp) return failure(); @@ -568,6 +596,10 @@ struct FoldEmptyTensorWithPackOp : public OpRewritePattern { LogicalResult matchAndRewrite(PackOp packOp, PatternRewriter &rewriter) const override { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + // Check for tensor.empty source. auto emptyOp = packOp.getSource().getDefiningOp(); if (!emptyOp) @@ -592,6 +624,10 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern { LogicalResult matchAndRewrite(UnPackOp unPackOp, PatternRewriter &rewriter) const override { + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unPackOp.hasPureTensorSemantics()) + return failure(); + // Check for tensor.empty source. 
auto emptyOp = unPackOp.getSource().getDefiningOp(); if (!emptyOp) diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 8a0440bcc6fb9..10e9fbd22a41d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -740,6 +740,10 @@ struct PackOpTiling ArrayRef offsets, ArrayRef sizes) const { auto packOp = cast(op); + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + Location loc = packOp.getLoc(); // The tiling is applied on interchanged dimensions. We have to undo the @@ -984,6 +988,10 @@ struct PackOpTiling ArrayRef sizes(allSizes[0]); auto packOp = cast(op); + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + Location loc = packOp.getLoc(); int64_t inputRank = packOp.getSourceRank(); @@ -1163,6 +1171,10 @@ struct UnPackOpTiling ArrayRef offsets, ArrayRef sizes) const { auto unpackOp = cast(op); + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unpackOp.hasPureTensorSemantics()) + return failure(); + int64_t srcRank = unpackOp.getSourceRank(); int64_t destRank = unpackOp.getDestRank(); int64_t numInnerTiles = srcRank - destRank; @@ -1333,6 +1345,10 @@ struct UnPackOpTiling return failure(); } auto unPackOp = cast(op); + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unPackOp.hasPureTensorSemantics()) + return failure(); + ArrayRef offsets(allOffsets[0]); ArrayRef sizes(allSizes[0]); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 027268cc20ddd..821b9d229ac42 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -26,6 +26,8 @@ #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" @@ -217,6 +219,10 @@ struct PackedOperandsDimList { FailureOr linalg::lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice) { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + // 1. Filter out NYI cases. auto packedTensorType = cast(packOp->getResultTypes().front()); @@ -344,11 +350,15 @@ FailureOr linalg::lowerPack(RewriterBase &rewriter, FailureOr linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice) { + // TODO: Support Memref UnPackOp. Temporarily return failure. 
+ if (!unPackOp.hasPureTensorSemantics()) + return failure(); + Location loc = unPackOp->getLoc(); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(unPackOp); - RankedTensorType packedTensorType = unPackOp.getSourceType(); + auto packedTensorType = cast(unPackOp.getSourceType()); int64_t packedRank = packedTensorType.getRank(); OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1); @@ -548,7 +558,7 @@ FailureOr linalg::pack(RewriterBase &rewriter, packOps.push_back(linalg::PackOp::create( rewriter, loc, operand, dest, innerPos, innerPackSizes, zero)); } - inputsAndInits.push_back(packOps.back()); + inputsAndInits.push_back(packOps.back().getResult()); } } @@ -575,7 +585,7 @@ FailureOr linalg::pack(RewriterBase &rewriter, unPackOps.push_back(linalg::UnPackOp::create( rewriter, packedLinalgOp->getLoc(), result, maybePackedInit.getSource(), maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles())); - results.push_back(unPackOps.back()); + results.push_back(unPackOps.back().getResult()); } // Step 5. Replace `linalgOp`. @@ -665,7 +675,7 @@ linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::PackOp transposedPackOp = packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm); - if (!packOp.getResult().hasOneUse()) + if (packOp.getNumResults() == 0 || !packOp.getResult().hasOneUse()) return rewriter.notifyMatchFailure(linalgOp, "expect single pack use"); OpOperand &packUse = *packOp->getUses().begin(); @@ -726,7 +736,10 @@ linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, } // Step 4. Finally, replace packOp now that we don't need it anymore. - rewriter.replaceOp(packOp, transposedPackOp->getResults()); + if (packOp.getNumResults() != 0) + rewriter.replaceOp(packOp, transposedPackOp->getResults()); + else + rewriter.eraseOp(packOp); return PackTransposeResult{transposedPackOp, transposedLinalgOp, transposedUnPackOp}; @@ -1018,6 +1031,10 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite( static Value getPackOpSourceOrPaddedSource(OpBuilder &builder, linalg::PackOp packOp) { Value input = packOp.getSource(); + // TODO: Support Memref PackOp. Temporarily return just Op Source. + if (!packOp.hasPureTensorSemantics()) + return input; + if (!packOp.getPaddingValue()) { return input; } @@ -1134,6 +1151,10 @@ getPackUnpackRankReducedPerm(ArrayRef shape, LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( linalg::PackOp packOp, PatternRewriter &rewriter) const { + // TODO: Support Memref PackOp. Temporarily return failure. + if (!packOp.hasPureTensorSemantics()) + return failure(); + if (llvm::any_of(packOp.getTiledOuterDims(), [](int64_t dim) { return dim != 1; })) { return rewriter.notifyMatchFailure( @@ -1155,8 +1176,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( // Check whether this dim has been permuted. Permuting unit dims is fine // as that's effectively a no-op. 
- if (dim < prev && (packOp.getType().getShape()[prev] != 1 || - packOp.getType().getShape()[dim] != 1)) + if (dim < prev && (packOp.getResult().getType().getShape()[prev] != 1 || + packOp.getResult().getType().getShape()[dim] != 1)) return false; prev = dim; @@ -1279,6 +1300,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const { + if (!unpackOp.hasPureTensorSemantics()) + return failure(); + int64_t srcRank = unpackOp.getSourceRank(); int64_t destRank = unpackOp.getDestRank(); ArrayRef srcShape = unpackOp.getSourceType().getShape(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index dcf84c46949f3..44c799d97e8aa 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1877,8 +1877,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, SmallVector preTransposeWriteVecSizses(writeVectorSizes); auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata); applyPermutationToVector(preTransposeWriteVecSizses, destInvPermutation); - auto preTransposeWriteVecType = VectorType::get( - preTransposeWriteVecSizses, packOp.getType().getElementType()); + auto preTransposeWriteVecType = + VectorType::get(preTransposeWriteVecSizses, + packOp.getResult().getType().getElementType()); // Compute vector type for the _read_ opeartion. This is simply // pre-transpose-write-vector-type with the dimensions collapsed @@ -1954,7 +1955,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(unpackOp); - RankedTensorType unpackTensorType = unpackOp.getSourceType(); + ShapedType unpackTensorType = unpackOp.getSourceType(); ArrayRef sourceShape = unpackTensorType.getShape(); bool useInBoundsInsteadOfMasking = false; @@ -2117,6 +2118,10 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, static LogicalResult vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, ArrayRef inputVectorSizes) { + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unpackOp.hasPureTensorSemantics()) + return failure(); + // If there are no input vector sizes and all shapes are static, there is // nothing left to check. if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() && @@ -2454,6 +2459,10 @@ static LogicalResult vectorizeLinalgOpPrecondition( static LogicalResult vectorizePackOpPrecondition(linalg::PackOp packOp, ArrayRef inputVectorSizes) { + // TODO: Support Memref PackOp. Temporarily return failure. 
+ if (!packOp.hasPureTensorSemantics()) + return failure(); + auto padValue = packOp.getPaddingValue(); Attribute cstAttr; // TODO: Relax this condiiton diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index f4020ede4854e..1f5995a680f71 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -2058,3 +2058,62 @@ func.func @no_fold_extract_slice_into_unpack_non_zero_offset( // CHECK-SAME: into %[[DEST]] // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]] // CHECK: return %[[SLICE]] + +// ----- + +// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size( +// CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>, +// CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> { +// CHECK: %[[RES:.*]] = linalg.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {test_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32> +// CHECK: return %[[RES]] : tensor<7x?xi32> +func.func @fold_cast_unpack_dynamic_tile_size( + %src: tensor<1x1x8x1xi32>, + %res: tensor<7x?xi32>) -> tensor<7x?xi32> { + + %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32> + %c8 = arith.constant 8 : index + %unpack = linalg.unpack %cast + inner_dims_pos = [0, 1] + inner_tiles = [%c8, 1] + into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32> + return %unpack : tensor<7x?xi32> +} + +// ----- + +// CHECK-LABEL: func.func @fold_pack_unpack_tensor +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK: return %[[ARG0]] : tensor<2x3xf32> +func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> { + %unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = [] + into %x : tensor<2x3xf32> -> tensor<2x3xf32> + %packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = [] + into %x : tensor<2x3xf32> -> tensor<2x3xf32> + return %packed : tensor<2x3xf32> +} + +// ----- + +// Test that pack/unpack canonicalization is disabled for memref versions +// CHECK-LABEL: func.func @pack_unpack_memref_no_canonicalization +// CHECK: linalg.pack +// CHECK: linalg.unpack +// CHECK: return +func.func @pack_unpack_memref_no_canonicalization(%source: memref<128x256xf32>, %packed: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) { + linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %packed : memref<128x256xf32> -> memref<16x8x8x32xf32> + linalg.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32> + return +} + +// ----- + +// Test that unpack/pack canonicalization is disabled for memref versions +// CHECK-LABEL: func.func @unpack_pack_memref_no_canonicalization +// CHECK: linalg.unpack +// CHECK: linalg.pack +// CHECK: return +func.func @unpack_pack_memref_no_canonicalization(%packed: memref<16x8x8x32xf32>, %unpacked: memref<128x256xf32>, %dest: memref<16x8x8x32xf32>) { + linalg.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %unpacked : memref<16x8x8x32xf32> -> memref<128x256xf32> + linalg.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %dest : memref<128x256xf32> -> memref<16x8x8x32xf32> + return +} \ No newline at end of file diff --git a/mlir/test/Dialect/Linalg/memref-pack-unpack.mlir b/mlir/test/Dialect/Linalg/memref-pack-unpack.mlir new file mode 100644 index 0000000000000..2701a1e3512a2 --- /dev/null +++ b/mlir/test/Dialect/Linalg/memref-pack-unpack.mlir @@ -0,0 +1,47 @@ +// RUN: 
mlir-opt %s -split-input-file | FileCheck %s + +// CHECK-LABEL: func @test_pack_memref +func.func @test_pack_memref(%arg0: memref<128x256xf32>, %arg1: memref<16x8x8x32xf32>) { + // CHECK-NOT: %{{.*}} = linalg.pack + // CHECK: linalg.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : memref<128x256xf32> -> memref<16x8x8x32xf32> + linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : memref<128x256xf32> -> memref<16x8x8x32xf32> + return +} + +// ----- + +// CHECK-LABEL: func @test_unpack_memref +func.func @test_unpack_memref(%arg0: memref<16x8x8x32xf32>, %arg1: memref<128x256xf32>) { + // CHECK: linalg.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : memref<16x8x8x32xf32> + linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : memref<16x8x8x32xf32> -> memref<128x256xf32> + return +} + +// ----- + +// CHECK-LABEL: func @test_pack_memref_with_padding +func.func @test_pack_memref_with_padding(%arg0: memref<127x255xf32>, %arg1: memref<16x8x8x32xf32>, %pad: f32) { + // CHECK: linalg.pack %{{.*}} padding_value(%{{.*}} : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : memref<127x255xf32> + linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : memref<127x255xf32> -> memref<16x8x8x32xf32> + return +} + +// ----- + +// CHECK-LABEL: func @test_pack_tensor +func.func @test_pack_tensor(%arg0: tensor<128x256xf32>, %arg1: tensor<16x8x8x32xf32>) -> tensor<16x8x8x32xf32> { + // CHECK: %[[RESULT:.*]] = linalg.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : tensor<128x256xf32> -> tensor<16x8x8x32xf32> + %0 = linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : tensor<128x256xf32> -> tensor<16x8x8x32xf32> + // CHECK: return %[[RESULT]] : tensor<16x8x8x32xf32> + return %0 : tensor<16x8x8x32xf32> +} + +// ----- + +// CHECK-LABEL: func @test_unpack_tensor +func.func @test_unpack_tensor(%arg0: tensor<16x8x8x32xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> { + // CHECK: %[[RESULT:.*]] = linalg.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : tensor<16x8x8x32xf32> -> tensor<128x256xf32> + %0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : tensor<16x8x8x32xf32> -> tensor<128x256xf32> + // CHECK: return %[[RESULT]] : tensor<128x256xf32> + return %0 : tensor<128x256xf32> +} \ No newline at end of file diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index 74928920c695a..71d0ffc417179 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -755,3 +755,29 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt: // CHECK-LABEL: func @conv2d_channel_first_q_promote( // CHECK: %[[arg0:[a-zA-z0-9]*]]: tensor<100x3x224x224xi8>, %[[arg1:[a-zA-z0-9]*]]: tensor<64x3x5x5xi8>, %[[arg2:[a-zA-z0-9]*]]: i8, %[[arg3:[a-zA-z0-9]*]]: i8) // CHECK: linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32> + +// ----- + +func.func @pack_memref(%source: memref<128x256xf32>, %dest: memref<8x16x8x32xf32>) { + linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32] + into %dest : 
memref<128x256xf32> -> memref<8x16x8x32xf32> + return +} + +// CHECK-LABEL: func @pack_memref( +// CHECK: %[[source:[a-zA-z0-9]*]]: memref<128x256xf32>, %[[dest:[a-zA-z0-9]*]]: memref<8x16x8x32xf32>) { +// CHECK: linalg.pack %[[source]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[dest]] : memref<128x256xf32> -> memref<8x16x8x32xf32> +// CHECK: return +// CHECK: } +// ----- + +func.func @unpack_memref(%source: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) { + linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] + into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32> + return +} + +// CHECK-LABEL: func @unpack_memref( +// CHECK: %[[source:[a-zA-z0-9]*]]: memref<16x8x8x32xf32>, %[[dest:[a-zA-z0-9]*]]: memref<128x256xf32>) { +// CHECK: linalg.unpack %[[source]] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[dest]] : memref<16x8x8x32xf32> -> memref<128x256xf32> +// CHECK: return \ No newline at end of file
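
The split of the old `inferPackedType` helper into `inferPackedShape`, `inferPackedTensorType`, and `inferPackedMemRefType` is easiest to see with a small usage sketch. The sketch below is illustrative only and not part of the patch: the wrapper function is hypothetical, while the helper names and signatures are those declared in the TableGen change above, and the 128x256 -> 16x8x8x32 layout matches the new memref tests.

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;

// Hypothetical helper showing how a caller could derive the destination type
// for a buffer-semantics linalg.pack. For a memref<128x256xf32> source with
// inner tiles [8, 32] on dims [0, 1], inferPackedShape yields
// {16, 8, 8, 32} (ceilDiv(128, 8) = 16, ceilDiv(256, 32) = 8, then the tile
// sizes are appended), and inferPackedMemRefType wraps that shape in a
// MemRefType with the source element type - i.e. memref<16x8x8x32xf32>, the
// destination used in the tests above. inferPackedTensorType does the same
// for RankedTensorType, so both agree on which dimensions are dynamic.
static MemRefType inferPackedDest(MemRefType srcType) {
  return linalg::PackOp::inferPackedMemRefType(srcType,
                                               /*innerTileSizes=*/{8, 32},
                                               /*innerDimsPos=*/{0, 1});
}
```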