
Conversation

@sakupan102
Contributor

Extend linalg.pack and linalg.unpack to accept memref operands in addition to tensors. As part of this change, we now disable all transformations when these ops have memref semantics.

Closes #129004
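For readers skimming the patch, a minimal sketch of the syntax this change enables, based on the custom parser/printer added in LinalgOps.cpp (the shapes and tile sizes are illustrative, not taken from the patch's tests):

```mlir
// Existing tensor form: destination-passing style, yields a new tensor result.
%packed = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
    into %dst : tensor<128x256xf32> -> tensor<16x8x8x32xf32>

// New memref form: writes into the destination buffer in place and
// produces no SSA result (the op has 0 results with buffer semantics).
linalg.pack %src_buf inner_dims_pos = [0, 1] inner_tiles = [8, 32]
    into %dst_buf : memref<128x256xf32> -> memref<16x8x8x32xf32>
```

Here 128/8 = 16 and 256/32 = 8 give the outer dims of the packed `16x8x8x32` layout; `linalg.unpack` follows the same pattern with source and destination roles reversed.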

@github-actions

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository, in which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment "Ping". The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Member

llvmbot commented Nov 12, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Ryutaro Okada (sakupan102)

Changes

Extend linalg.pack and linalg.unpack to accept memref operands in addition to tensors. As part of this change, we now disable all transformations when these ops have memref semantics.

Closes #129004


Patch is 68.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/167675.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td (+46-47)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+431-48)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp (+4)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+51-18)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp (+43-7)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+16)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+31-7)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+12-3)
  • (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+59)
  • (added) mlir/test/Dialect/Linalg/memref-pack-unpack.mlir (+47)
  • (modified) mlir/test/Dialect/Linalg/roundtrip.mlir (+26)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 6504ca8664d49..6a47fe43adf90 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -30,23 +30,26 @@ include "mlir/IR/OpAsmInterface.td"
 // RelayoutOp
 //===----------------------------------------------------------------------===//
 
-class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
-      Op<Linalg_Dialect, mnemonic, !listconcat(traits, [
-        DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-        DestinationStyleOpInterface, LinalgRelayoutOpInterface,
-        ConditionallySpeculatable, NoMemoryEffect,
-        DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
-        TypesMatchWith<"result type matches type of dest",
-                   "dest", "result",
-                   "$_self">])> {
+class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []>
+    : Op<Linalg_Dialect, mnemonic,
+         !listconcat(
+             traits, [DeclareOpInterfaceMethods<
+                          OpAsmOpInterface, ["getAsmResultNames"]>,
+                      DestinationStyleOpInterface, LinalgRelayoutOpInterface,
+                      ConditionallySpeculatable,
+                      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+                      DeclareOpInterfaceMethods<
+                          ReifyRankedShapedTypeOpInterface>,
+                      OptionalTypesMatchWith<"result type matches type of dest",
+                                             "dest", "result", "$_self">])> {
 
   code commonExtraClassDeclaration = [{
     size_t getSourceRank() { return getSourceType().getRank(); };
     size_t getDestRank() { return getDestType().getRank(); };
-    RankedTensorType getSourceType() {
-      return ::llvm::cast<RankedTensorType>(getSource().getType()); };
-    RankedTensorType getDestType() {
-      return ::llvm::cast<RankedTensorType>(getDest().getType()); };
+    ShapedType getSourceType() {
+      return ::llvm::cast<ShapedType>(getSource().getType()); };
+    ShapedType getDestType() {
+      return ::llvm::cast<ShapedType>(getDest().getType()); };
 
     MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
 
@@ -191,23 +194,12 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     //            expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
     ```
   }];
-  let arguments = (ins AnyRankedTensor:$source,
-                       AnyRankedTensor:$dest,
-                       Optional<AnyType>:$padding_value,
-                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
-                       DenseI64ArrayAttr:$inner_dims_pos,
-                       Variadic<Index>:$inner_tiles,
-                       DenseI64ArrayAttr:$static_inner_tiles);
-  let results = (outs AnyRankedTensor:$result);
-  let assemblyFormat = [{
-    $source
-    (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
-    (`outer_dims_perm` `=` $outer_dims_perm^)?
-    `inner_dims_pos` `=` $inner_dims_pos
-    `inner_tiles` `=`
-    custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
-    `into` $dest attr-dict `:` type($source) `->` type($dest)
-  }];
+  let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
+      TensorOrMemRef<[AnyType]>:$dest, Optional<AnyType>:$padding_value,
+      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
+      DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
+      DenseI64ArrayAttr:$static_inner_tiles);
+  let results = (outs Optional<AnyRankedTensor>:$result);
 
   let builders = [
     OpBuilder<(ins "Value":$source, "Value":$dest,
@@ -217,7 +209,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
       CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
   ];
 
-  let extraClassDeclaration = commonExtraClassDeclaration # [{
+  let extraClassDeclaration = commonExtraClassDeclaration#[{
     // Method to get the shape of the result as `SmallVector<OpFoldResult>`.
     // This is a static method to allow getting the shape of the destination
     // expected while creating a `pack` op.
@@ -229,7 +221,19 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     // Method to get the `RankedTensorType` of the result based on the inner
     // tiles, position of the inner tiles (innerDimsPos)  and interchange vector
     // of outer loops (outerDimsPerm).
-    static RankedTensorType inferPackedType(RankedTensorType sourceType,
+    static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
+        ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
+        ArrayRef<int64_t> outerDimsPerm = {});
+
+    // Method to get the `MemRefType` of the result based on the inner
+    // tiles, position of the inner tiles (innerDimsPos)  and interchange vector
+    // of outer loops (outerDimsPerm).
+    static MemRefType inferPackedMemRefType(MemRefType sourceType,
+        ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
+        ArrayRef<int64_t> outerDimsPerm = {});
+
+    // Returns the shape of the packed type. This shared helper ensures that the type inference methods agree on which dimensions are dynamic.
+    static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
         ArrayRef<int64_t> outerDimsPerm = {});
 
@@ -281,6 +285,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
   let hasCanonicalizeMethod = 1;
 
   let hasFolder = 1;
+
+  let hasCustomAssemblyFormat = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -348,21 +354,12 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
     //          Outer Dims: 9x3x8   Inner Dims: 4x2
     ```
   }];
-  let arguments = (ins AnyRankedTensor:$source,
-                       AnyRankedTensor:$dest,
-                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
-                       DenseI64ArrayAttr:$inner_dims_pos,
-                       Variadic<Index>:$inner_tiles,
-                       DenseI64ArrayAttr:$static_inner_tiles);
-  let results = (outs AnyRankedTensor:$result);
-  let assemblyFormat = [{
-    $source
-    (`outer_dims_perm` `=` $outer_dims_perm^)?
-    `inner_dims_pos` `=` $inner_dims_pos
-    `inner_tiles` `=`
-    custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
-    `into` $dest attr-dict `:` type($source) `->` type($dest)
-  }];
+  let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
+      TensorOrMemRef<[AnyType]>:$dest,
+      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
+      DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
+      DenseI64ArrayAttr:$static_inner_tiles);
+  let results = (outs Optional<AnyRankedTensor>:$result);
 
   let builders = [
     OpBuilder<(ins "Value":$source, "Value":$dest,
@@ -405,6 +402,8 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
   let hasCanonicalizeMethod = 1;
 
   let hasFolder = 1;
+
+  let hasCustomAssemblyFormat = 1;
 }
 
 #endif // LINALG_RELEAYOUT_OPS
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3dc45edf4a23f..d268ddd613829 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4968,12 +4968,12 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
 template <typename OpTy, typename>
 SmallVector<int64_t>
 getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
-  RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
-                                    ? packOrUnPack.getDestType()
-                                    : packOrUnPack.getSourceType();
-  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
-                                      ? packOrUnPack.getSourceType()
-                                      : packOrUnPack.getDestType();
+  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
+                              ? packOrUnPack.getDestType()
+                              : packOrUnPack.getSourceType();
+  ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
+                                ? packOrUnPack.getSourceType()
+                                : packOrUnPack.getDestType();
   SmallVector<int64_t> result(
       packedType.getShape().take_front(unpackedType.getRank()));
   if (!packOrUnPack.getOuterDimsPerm().empty()) {
@@ -5107,15 +5107,34 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
     return llvm::any_of(tiles, isZeroInteger);
   };
 
+  // Verify that the source and destination are ranked types.
+  if (!packOrUnPack.getSourceType().hasRank() ||
+      !packOrUnPack.getDestType().hasRank()) {
+    return op->emitError("expected both source and destination to have rank");
+  }
+
+  // Verify that the Operation does not have mixed tensor/buffer semantics.
+  if (!packOrUnPack.hasPureBufferSemantics() &&
+      !packOrUnPack.hasPureTensorSemantics()) {
+    return op->emitError("mixing tensor and buffer semantics is not allowed");
+  }
+  const unsigned numResults = packOrUnPack.getNumResults();
+  if (packOrUnPack.hasPureTensorSemantics() && numResults != 1) {
+    return op->emitError("expected 1 result, got ") << numResults;
+  }
+  if (packOrUnPack.hasPureBufferSemantics() && numResults != 0) {
+    return op->emitError("expected 0 results, got ") << numResults;
+  }
+
   // Verify tiles. Do not allow zero tiles.
   SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
   if (hasZeros(mixedTiles))
     return op->emitError("invalid zero tile factor");
 
   // Verify inner_dims_pos and outer_dims_perm.
-  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
-                                      ? packOrUnPack.getSourceType()
-                                      : packOrUnPack.getDestType();
+  ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
+                                ? packOrUnPack.getSourceType()
+                                : packOrUnPack.getDestType();
   size_t unpackedRank = unpackedType.getRank();
   ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
   ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
@@ -5152,8 +5171,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   // Verify result shape is greater than the minimum expected
   // by the pack operation, and that the output shape
   // represents full tiles.
-  RankedTensorType expectedPackedType = PackOp::inferPackedType(
-      unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
+  SmallVector<int64_t> expectedPackedShape = PackOp::inferPackedShape(
+      unpackedType.getShape(), packOrUnPack.getStaticTiles(),
+      packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
   if (!llvm::all_of(
           llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                     mixedTiles),
@@ -5170,11 +5190,20 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
     return op->emitError("mismatch in inner tile sizes specified and shaped of "
                          "tiled dimension in the packed type");
   }
-  if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
-                                   packedType.getShape()))) {
+  if (failed(
+          verifyCompatibleShape(expectedPackedShape, packedType.getShape()))) {
+    auto elementType = unpackedType.getElementType();
+    Type expectedType, actualType;
+    if (packOrUnPack.hasPureTensorSemantics()) {
+      expectedType = RankedTensorType::get(expectedPackedShape, elementType);
+      actualType = RankedTensorType::get(packedType.getShape(), elementType);
+    } else {
+      expectedType = MemRefType::get(expectedPackedShape, elementType);
+      actualType = MemRefType::get(packedType.getShape(), elementType);
+    }
     return op->emitError("expected ")
-           << expectedPackedType << " for the packed domain value, got "
-           << packedType;
+           << expectedType << " for the packed domain value, got "
+           << actualType;
   }
   return success();
 }
@@ -5235,9 +5264,158 @@ commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
 //===----------------------------------------------------------------------===//
 
 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+  if (getNumResults() == 0)
+    return;
   setNameFn(getResult(), "pack");
 }
 
+ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::UnresolvedOperand source, dest;
+  SmallVector<OpAsmParser::UnresolvedOperand> dynamicTiles;
+  SmallVector<OpAsmParser::UnresolvedOperand> paddingValue;
+  SmallVector<Type> paddingValueType;
+  SmallVector<int64_t> staticTiles;
+  DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
+  Type sourceType, destType, resultType;
+
+  if (parser.parseOperand(source))
+    return failure();
+
+  if (succeeded(parser.parseOptionalKeyword("padding_value"))) {
+    if (parser.parseLParen() ||
+        parser.parseOperandList(paddingValue, /*requiredOperandCount=*/1) ||
+        parser.parseColon() || parser.parseTypeList(paddingValueType) ||
+        parser.parseRParen())
+      return failure();
+  }
+
+  if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
+    if (parser.parseEqual())
+      return failure();
+
+    SmallVector<int64_t> outerDimsPermVec;
+    if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
+          int64_t value;
+          if (parser.parseInteger(value))
+            return failure();
+          outerDimsPermVec.push_back(value);
+          return success();
+        }))
+      return failure();
+    outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
+  }
+
+  if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
+    return failure();
+
+  SmallVector<int64_t> innerDimsPosVec;
+  if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
+        int64_t value;
+        if (parser.parseInteger(value))
+          return failure();
+        innerDimsPosVec.push_back(value);
+        return success();
+      }))
+    return failure();
+  innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);
+
+  if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
+    return failure();
+
+  DenseI64ArrayAttr staticTilesAttr;
+  if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
+    return failure();
+  for (auto val : staticTilesAttr.asArrayRef())
+    staticTiles.push_back(val);
+
+  if (parser.parseKeyword("into") || parser.parseOperand(dest))
+    return failure();
+
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  if (parser.parseColon() || parser.parseType(sourceType))
+    return failure();
+
+  bool hasArrow = succeeded(parser.parseOptionalArrow());
+  if (hasArrow) {
+    if (parser.parseType(destType))
+      return failure();
+  }
+
+  bool isMemRef = llvm::isa<MemRefType>(sourceType);
+  if (!hasArrow) {
+    return parser.emitError(parser.getCurrentLocation(),
+                            "pack/unpack requires '->' and destination type");
+  }
+
+  if (!isMemRef) {
+    resultType = destType;
+  }
+
+  if (parser.resolveOperand(source, sourceType, result.operands) ||
+      parser.resolveOperand(dest, destType, result.operands))
+    return failure();
+
+  if (!paddingValue.empty() &&
+      parser.resolveOperands(paddingValue, paddingValueType[0],
+                             result.operands))
+    return failure();
+
+  if (!dynamicTiles.empty() &&
+      parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
+                             result.operands))
+    return failure();
+
+  result.addAttribute("static_inner_tiles",
+                      parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
+  result.addAttribute("inner_dims_pos", innerDimsPos);
+  if (outerDimsPerm)
+    result.addAttribute("outer_dims_perm", outerDimsPerm);
+
+  SmallVector<int32_t> segmentSizes = {
+      1, 1, static_cast<int32_t>(paddingValue.size()),
+      static_cast<int32_t>(dynamicTiles.size())};
+  result.addAttribute("operandSegmentSizes",
+                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
+
+  if (!isMemRef)
+    result.addTypes(resultType);
+
+  return success();
+}
+
+void PackOp::print(OpAsmPrinter &p) {
+  p << " " << getSource();
+
+  if (getPaddingValue()) {
+    p << " padding_value(" << getPaddingValue() << " : "
+      << getPaddingValue().getType() << ")";
+  }
+
+  if (!getOuterDimsPerm().empty()) {
+    p << " outer_dims_perm = [";
+    llvm::interleaveComma(getOuterDimsPerm(), p);
+    p << "]";
+  }
+
+  p << " inner_dims_pos = [";
+  llvm::interleaveComma(getInnerDimsPos(), p);
+  p << "]";
+
+  p << " inner_tiles = ";
+  printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
+
+  p << " into " << getDest();
+
+  p.printOptionalAttrDict((*this)->getAttrs(),
+                          {"static_inner_tiles", "inner_dims_pos",
+                           "outer_dims_perm", "operandSegmentSizes"});
+
+  p << " : " << getSource().getType();
+  p << " -> " << getDest().getType();
+}
+
 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
                    Value dest, ArrayRef<int64_t> innerDimsPos,
                    ArrayRef<OpFoldResult> innerTiles,
@@ -5395,13 +5573,11 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
   return result;
 }
 
-/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
-/// the packed type. Having a shared helper helps implement these two methods in
-/// a way that ensures that they agree on which dimensions are dynamic.
-static SmallVector<int64_t> getPackOpResultTypeShape(
-    ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
-    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
-  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
+SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
+                                              ArrayRef<int64_t> innerTileSizes,
+                                              ArrayRef<int64_t> innerDimsPos,
+                                              ArrayRef<int64_t> outerDimsPerm) {
+  SmallVector<int64_t> resultShape = llvm::to_vector(inputShape);
   for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
     if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
       continue;
@@ -5441,9 +5617,9 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
   resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
 
   SmallVector<int64_t> resultTypeShape =
-      getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
-                               asShapeWithAnyValueAsDynamic(innerTileSizes),
-                               innerDimsPos, outerDimsPerm);
+      inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims),
+                       asShapeWithAnyValueAsDynamic(innerTileSizes),
+                       innerDimsPos, outerDimsPerm);
 
   // Fix-up `resultDims` to ensure that they are Value's if and only if the
   // result type shape says it's a dynamic dim. This is needed as callers may
@@ -5459,15 +5635,21 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
   return resultDims;
 }
 
-/// Get the expected packed type based on source type, tile factors, position of
-/// the inner tiles and permutation of the outer tiled loop.
-RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
+RankedTensorType PackOp::inferPackedTensorType(
+    RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
+    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
+  SmallVector<int64_t> resultShape = inferPackedShape(
+      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
+  return RankedTensorType::get(resultShape, sourceType.getElementType());
+}
+
+MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
         ...
[truncated]

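To make the new verifier checks concrete, here is an illustrative (hypothetical) example of IR that the added `commonVerifierPackAndUnPackOp` logic rejects — operands may be all tensors or all memrefs, but not a mix:

```mlir
// Rejected by the new verifier with the diagnostic:
//   "mixing tensor and buffer semantics is not allowed"
linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
    into %dst : tensor<128x256xf32> -> memref<16x8x8x32xf32>
```

The verifier also enforces the result count that follows from the chosen semantics: exactly one result with pure tensor semantics, zero results with pure buffer semantics.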
Extend linalg.pack and linalg.unpack to accept memref operands in addition to tensors.
As part of this change, we now disable all transformations when these ops have memref semantics.

Closes llvm#129004

Co-authored-by: Hyunsung Lee <[email protected]>
Signed-off-by: Ryutaro Okada <[email protected]>
@sakupan102
Contributor Author

Could anyone review this?

@hanhanW
Contributor

hanhanW commented Nov 19, 2025

Could anyone review this?

Sorry that I was busy on many reviews and CI issues. I'll review it tomorrow. Thanks for picking this up!



Successfully merging this pull request may close these issues.

[mlir][linalg] Linalg::PackOp Linalg::UnPackOp to support memref input

3 participants