Skip to content

Conversation

@banach-space
Copy link
Contributor

@banach-space banach-space commented Oct 15, 2025

This patch changes vectorizeAsTensorPackOp to require users to specify
all write-side vector sizes for linalg.pack (not just the outer
dimensions). This makes linalg.pack vectorization consistent with
linalg.unpack (see #149293 for a similar change).

Conceptually, linalg.pack consists of these high-level steps:

  • Read from the source tensor using vector.transfer_read.
  • Re-associate dimensions of the read value, as specified by
    the op (via vector.shape_cast)
  • Transpose the re-associated value according to the permutation
    in the linalg.pack op (via vector.transpose).
  • Write the result into the destination tensor via
    vector.transfer_write.

Previously, the vector sizes provided by the user were interpreted as
write-vector-sizes for PackOp outer dims (i.e. the final step above).
These were used to:

  • Infer read-vector-sizes using the inner_tiles attribute of PackOp.
  • Deduce vector sizes for the transpose and shape cast operations.
  • Ultimately determine the vector shape for the read.

However, this logic breaks when one or more tile sizes are dynamic (*).
In such cases, vectorizePackOpPrecondition would currently fail (see
@pack_with_dynamic_dims_and_dynamic_inner_tile added in this PR -
without this change it will crash).

This patch updates the contract: users now directly specify all the
"write-vector-sizes", which inherently encode all inner tile sizes - including
dynamic ones. It becomes the user's responsibility to provide valid sizes.

In practice, since linalg.pack is typically constructed, tiled, and
vectorized by the same transformation pipeline, the necessary
"write-vector-sizes" should be recoverable.

Notes for reviewers:

  • See test updates for user-facing impact.
  • Review vectorizeAsTensorPackOp as a new implementation rather than
    a diff.
  • Comments and variable names were updated to align with
    vectorizeAsTensorUnPackOp.

(*) As a concrete example, "scalable" tile sizes are represent as
dynamic values. Note, support for "scalable" vectorisation will be added
in a separate PR.

@banach-space banach-space changed the base branch from main to users/banach-space/linalg/update_vec_tests October 15, 2025 10:07
@llvmbot
Copy link
Member

llvmbot commented Oct 15, 2025

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][linalg][nfc] Clean-up vectorization tests
  • [mlir][linalg] Update vectorizatio of linalg.pack

Patch is 22.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163539.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Utils/Utils.h (+2-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+116-98)
  • (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+4-4)
  • (modified) mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir (+2)
  • (modified) mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir (+93-5)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 48978eb7663d5..49c75f4b00280 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -37,7 +37,8 @@ namespace linalg {
 /// This function uses the helper function `computePackUnPackPerm` to get
 /// the permutation vector. Only major difference between UnPack and Pack is
 /// that packOp uses destination rank whereas unpack Uses source rank.
-SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp);
+SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp,
+                                            PackingMetadata &metadatap);
 
 /// Shell function to compute the Source Permutation of unPackOp.
 /// This function, like the getPackInverseDestPerm uses the helper function
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index eb2d825e17e44..12b6da774701c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -234,8 +234,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   // before any outer or inner permutations have been applied.
   PackingMetadata packingMetadata = computePackingMetadata(
       packedTensorType.getRank(), packOp.getInnerDimsPos());
+  PackingMetadata packMetadata;
   SmallVector<int64_t> packedToStripMinedShapePerm =
-      getPackInverseDestPerm(packOp);
+      getPackInverseDestPerm(packOp, packMetadata);
 
   // 3. Compute the stripMinedShape: this is the packed shape before any outer
   // or inner permutations have been applied.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 9d62491214018..e460797a309c4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1568,7 +1568,9 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
 /// permutations.
 static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
                                               ArrayRef<int64_t> destShape) {
-  return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
+  PackingMetadata metadata;
+  return applyPermutation(destShape,
+                          linalg::getPackInverseDestPerm(packOp, metadata));
 }
 
 /// Determines whether a mask for xfer_write is trivially "all true"
@@ -1761,99 +1763,6 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
   return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
-/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
-/// padding value and (3) input vector sizes into:
-///
-///   masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
-///
-/// As in the following example:
-/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
-///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
-///
-/// This pack would be vectorized to:
-///
-/// %load = vector.mask %mask {
-///     vector.transfer_read %arg0[%c0, %c0, %c0], %cst
-///         {in_bounds = [true, true, true]} :
-///         tensor<32x7x16xf32>, vector<32x8x16xf32>
-/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
-/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
-///                                         to vector<32x4x2x1x16xf32>
-/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
-///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
-/// %write = vector.transfer_write %transpose,
-///     %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
-///     {in_bounds = [true, true, true, true, true]}
-///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
-///
-/// If the (3) input vector sizes are not provided, the vector sizes are
-/// determined by the result tensor shape and the `in_bounds`
-/// attribute is used instead of masking to mark out-of-bounds accesses.
-///
-/// NOTE: The input vector sizes specify the dimensions corresponding to the
-/// outer dimensions of the output tensor. The remaining dimensions are
-/// computed based on, e.g., the static inner tiles.
-/// Supporting dynamic inner tiles will require the user to specify the
-/// missing vector sizes. This is left as a TODO.
-static LogicalResult
-vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        SmallVectorImpl<Value> &newResults) {
-  // TODO: Introduce a parent class that will handle the insertion point update.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(packOp);
-
-  Location loc = packOp.getLoc();
-  std::optional<Value> padValue = packOp.getPaddingValue()
-                                      ? std::optional(packOp.getPaddingValue())
-                                      : std::nullopt;
-
-  // If the input vector sizes are not provided, then the vector sizes are
-  // determined by the result tensor shape. In case the vector sizes aren't
-  // provided, we update the inBounds attribute instead of masking.
-  bool useInBoundsInsteadOfMasking = false;
-  if (inputVectorSizes.empty()) {
-    ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
-    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
-    useInBoundsInsteadOfMasking = true;
-  }
-
-  // Create masked TransferReadOp.
-  SmallVector<int64_t> inputShape(inputVectorSizes);
-  auto innerTiles = packOp.getStaticInnerTiles();
-  auto innerDimsPos = packOp.getInnerDimsPos();
-  auto outerDimsPerm = packOp.getOuterDimsPerm();
-  if (!outerDimsPerm.empty())
-    applyPermutationToVector(inputShape,
-                             invertPermutationVector(outerDimsPerm));
-  for (auto [idx, size] : enumerate(innerTiles))
-    inputShape[innerDimsPos[idx]] *= size;
-  auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, packOp.getSource(), inputShape, padValue,
-      useInBoundsInsteadOfMasking,
-      /*inputScalableVecSizes=*/{});
-
-  // Create ShapeCastOp.
-  SmallVector<int64_t> destShape(inputVectorSizes);
-  destShape.append(innerTiles.begin(), innerTiles.end());
-  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
-                                       packOp.getDestType().getElementType());
-  auto shapeCastOp =
-      vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);
-
-  // Create TransposeOp.
-  auto destPermutation =
-      invertPermutationVector(getPackInverseDestPerm(packOp));
-  auto transposeOp = vector::TransposeOp::create(
-      rewriter, loc, shapeCastOp.getResult(), destPermutation);
-
-  // Create TransferWriteOp.
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), packOp.getDest());
-  newResults.push_back(write->getResult(0));
-  return success();
-}
-
 /// Given the re-associations, "collapses" the input Vector type
 ///
 /// This is similar to CollapseShapeOp::inferCollapsedType with two notable
@@ -1901,12 +1810,119 @@ static VectorType getCollapsedVecType(VectorType type,
   return VectorType::get(newShape, type.getElementType(), newScalableFlags);
 }
 
+/// Vectorize `linalg.pack` as:
+///   * xfer_read -> shape_cast -> transpose -> xfer_write
+///
+/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
+/// sizes for the xfer_write operation). This is sufficient to infer the other
+/// vector sizes required here.
+///
+/// If the vector sizes are not provided:
+///  * the vector sizes are determined from the destination tensor static shape.
+///  * the inBounds attribute is used instead of masking.
+///
+/// EXAMPLE (no vector sizes):
+/// ```
+///   %pack = tensor.pack %src
+///     inner_dims_pos = [2, 1]
+///     inner_tiles = [16, 2]
+///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+/// ``
+/// is vectorizes as:
+/// ```
+///   %read = vector.transfer_read %src
+///     : tensor<32x7x16xf32>, vector<32x8x16xf32>
+///   %sc = vector.shape_cast %read
+///     : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+///   %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
+///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+///   %write = vector.transfer_write %tr into %dest
+///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+/// ```
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
+                        ArrayRef<int64_t> inputVectorSizes,
+                        SmallVectorImpl<Value> &newResults) {
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() == packOp.getDestRank() &&
+           "Invalid number of input vector sizes!");
+  }
+
+  // TODO: Introduce a parent class that will handle the insertion point update.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
+  Location loc = packOp.getLoc();
+  std::optional<Value> padValue = packOp.getPaddingValue()
+                                      ? std::optional(packOp.getPaddingValue())
+                                      : std::nullopt;
+
+  SmallVector<int64_t> destShape =
+      SmallVector<int64_t>(packOp.getDestType().getShape());
+
+  // This is just a convenience alias to clearly communicate that the input
+  // vector sizes determine the _write_ sizes.
+  ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;
+
+  // In the absence of input-vector-sizes, use the _static_ input tensor shape.
+  // In addition, use the inBounds attribute instead of masking.
+  bool useInBoundsInsteadOfMasking = false;
+  if (writeVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(destShape))
+      return rewriter.notifyMatchFailure(packOp,
+                                         "Unable to infer vector sizes!");
+
+    writeVectorSizes = destShape;
+    useInBoundsInsteadOfMasking = true;
+  }
+
+  // Compute vector type for the _read_ opeartion. The required dims are
+  // determined based on the _write_ vector sizes. This is done in two
+  // steps:
+  //  1) Invert the permutation/transposition that's part of the Pack
+  //  operation.
+  //  2) Collapse the tiled sizes/dims to "return" to the unpacked domain.
+  PackingMetadata packMetadata;
+  auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata);
+
+  SmallVector<int64_t> inputVecSizesPrePerm(writeVectorSizes);
+  applyPermutationToVector(inputVecSizesPrePerm, destInvPermutation);
+
+  VectorType readVecType = getCollapsedVecType(
+      VectorType::get(inputVecSizesPrePerm, packOp.getType().getElementType()),
+      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+          rewriter.getContext(), packMetadata.reassociations)));
+
+  // Create masked TransferReadOp.
+  auto maskedRead = vector::createReadOrMaskedRead(
+      rewriter, loc, packOp.getSource(), readVecType.getShape(), padValue,
+      useInBoundsInsteadOfMasking,
+      /*inputScalableVecSizes=*/{});
+
+  // Create ShapeCastOp.
+  auto expandedVecType =
+      VectorType::get(inputVecSizesPrePerm, packOp.getType().getElementType());
+  auto shapeCastOp =
+      vector::ShapeCastOp::create(rewriter, loc, expandedVecType, maskedRead);
+
+  // Create TransposeOp.
+  auto destPermutation = invertPermutationVector(destInvPermutation);
+  auto transposeOp = vector::TransposeOp::create(
+      rewriter, loc, shapeCastOp.getResult(), destPermutation);
+
+  // Create TransferWriteOp.
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, transposeOp.getResult(), packOp.getDest());
+  newResults.push_back(write->getResult(0));
+  return success();
+}
+
 /// Vectorize `linalg.unpack` as:
 ///   * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
 ///
-/// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
-/// for the xfer_read operation). This is sufficient to infer the other vector
-/// sizes required here.
+/// The input-vector-sizes specify the _read_ vector sizes (i.e. the vector
+/// sizes for the xfer_read operation). This is sufficient to infer the other
+/// vector sizes required here.
 ///
 /// If the vector sizes are not provided:
 ///  * the vector sizes are determined from the input tensor static shape.
@@ -1960,7 +1976,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   // In the absence of input-vector-sizes, use the _static_ input tensor shape.
   if (inputVectorSizes.empty()) {
     if (ShapedType::isDynamicShape(sourceShape))
-      return failure();
+      return rewriter.notifyMatchFailure(unpackOp,
+                                         "Unable to infer vector sizes!");
 
     readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
     useInBoundsInsteadOfMasking = true;
@@ -2443,6 +2460,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
   Attribute cstAttr;
+  // TODO: Relax this condiiton
   if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
     LDBG() << "pad value is not constant: " << packOp;
     return failure();
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 24d3722cf5426..a91397d29f3e3 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -171,9 +171,9 @@ computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
 namespace mlir {
 namespace linalg {
 
-SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp) {
+SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp,
+                                            PackingMetadata &pMetadata) {
 
-  PackingMetadata pMetadata;
   int64_t packedRank = packOp.getDestType().getRank();
   ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
   ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
@@ -189,11 +189,11 @@ SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp) {
 
 SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp,
                                              PackingMetadata &metadata) {
-  int64_t unpackRank = unpackOp.getSourceType().getRank();
+  int64_t packedRank = unpackOp.getSourceType().getRank();
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
   SmallVector<int64_t> unpackInvSrcPerm =
-      computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata);
+      computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
   return unpackInvSrcPerm;
 }
 
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
index 93a03369be239..cd472802dd307 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
@@ -285,6 +285,8 @@ module attributes {transform.with_named_sequence} {
 
 ///----------------------------------------------------------------------------------------
 /// Tests for linalg.pack
+///
+/// TODO: Add similar tests for linalg.unpack
 ///----------------------------------------------------------------------------------------
 
 // Note, see a similar test in:
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 1304a90349f71..6d3544ff4f23d 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -1335,7 +1335,7 @@ func.func @pack_no_padding(%src: tensor<32x8x16xf32>, %dest: tensor<4x1x32x16x2x
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%src: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.pack"]} in %src : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [4, 1, 32] : !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, 1, 32, 16, 2] : !transform.any_op
     transform.yield
   }
 }
@@ -1378,7 +1378,7 @@ func.func @pack_with_padding(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [32, 4, 1] : !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [32, 4, 1, 16, 2] : !transform.any_op
     transform.yield
   }
 }
@@ -1424,8 +1424,15 @@ module attributes {transform.with_named_sequence} {
 // CHECK-LABEL: func @pack_with_dynamic_dims
 // CHECK-SAME:      %[[SRC:.*]]: tensor<?x?xf32>,
 // CHECK-SAME:      %[[DEST:.*]]: tensor<?x?x16x2xf32>
-func.func @pack_with_dynamic_dims(%src: tensor<?x?xf32>, %dest: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
-  %pack = linalg.pack %src inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %dest : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
+func.func @pack_with_dynamic_dims(
+    %src: tensor<?x?xf32>, 
+    %dest: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
+
+  %pack = linalg.pack %src 
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 2]
+    into %dest : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
+
   return %pack : tensor<?x?x16x2xf32>
 }
 
@@ -1433,30 +1440,111 @@ func.func @pack_with_dynamic_dims(%src: tensor<?x?xf32>, %dest: tensor<?x?x16x2x
 //  CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
 //  CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
 //  CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
+
+/// Compute mask for xfer_read
 //  CHECK-DAG: %[[D0_0:.*]] = tensor.dim {{.*}} %[[C0_0]] : tensor<?x?xf32>
 //  CHECK-DAG: %[[D1_0:.*]] = tensor.dim {{.*}} %[[C1_0]] : tensor<?x?xf32>
 //      CHECK: %[[MASK:.*]] = vector.create_mask %[[D0_0]], %[[D1_0]] : vector<8x16xi1>
+
+/// --= read =---
 //      CHECK: %[[READ:.*]] = vector.mask %[[MASK]] {
 // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[C0_1]], %[[C0_1]]], %[[CST]]
 // CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
 // CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
+
+/// --= shape_cast =---
 //      CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<8x16xf32> to vector<4x2x1x16xf32>
+
+/// --= transpose =---
 //      CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
+
+/// Compute mask for xfer_write
 //  CHECK-DAG: %[[C0_2:.*]] = arith.constant 0 : index
 //  CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
 //  CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 //  CHECK-DAG: %[[D2:.*]] = tensor.dim %[[DEST]], {{.*}} : tensor<?x?x16x2xf32>
 //  CHECK-DAG: %[[D3:.*]] = tensor.dim %[[DEST]], {{.*}} : tensor<?x?x16x2xf32>
 //      CHECK: %[[MASK_0:.*]] = vector.create_mask %[[D2]], %[[D3]], %[[C16]], %[[C2]] : vector<4x1x16x2xi1>
+
+/// --= write =---
 //      CHECK: %[[WRITE:.*]] = vector.mask %[[MASK_0]] {
 // CHECK-SAME:   vector.transfer_write %[[TR]], %[[DEST]][%[[C0_2]], %[[C0_2]], %[[C0_2]], %[[C0_2]]]
 // CHECK-SAME:   {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<?x?x16x2xf32>
+
 //      CHECK: return %[[WRITE]] : tensor<?x?x16x2xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, 1, 16, 2] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+/// Similar to the test above, but one of the inner tile sizes is dynamic. As a
+/// result, more output dims are dynamic (and, e.g., output mask calcuation is a bit different).
+
+// CHECK-LABEL: func @pack_with_dynamic_dims_and_dynamic_inner_tile
+// CHECK-SAME:      %[[SRC:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:     ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 15, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][linalg][nfc] Clean-up vectorization tests
  • [mlir][linalg] Update vectorizatio of linalg.pack

Patch is 22.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163539.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Utils/Utils.h (+2-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+116-98)
  • (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+4-4)
  • (modified) mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir (+2)
  • (modified) mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir (+93-5)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 48978eb7663d5..49c75f4b00280 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -37,7 +37,8 @@ namespace linalg {
 /// This function uses the helper function `computePackUnPackPerm` to get
 /// the permutation vector. Only major difference between UnPack and Pack is
 /// that packOp uses destination rank whereas unpack Uses source rank.
-SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp);
+SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp,
+                                            PackingMetadata &metadatap);
 
 /// Shell function to compute the Source Permutation of unPackOp.
 /// This function, like the getPackInverseDestPerm uses the helper function
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index eb2d825e17e44..12b6da774701c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -234,8 +234,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   // before any outer or inner permutations have been applied.
   PackingMetadata packingMetadata = computePackingMetadata(
       packedTensorType.getRank(), packOp.getInnerDimsPos());
+  PackingMetadata packMetadata;
   SmallVector<int64_t> packedToStripMinedShapePerm =
-      getPackInverseDestPerm(packOp);
+      getPackInverseDestPerm(packOp, packMetadata);
 
   // 3. Compute the stripMinedShape: this is the packed shape before any outer
   // or inner permutations have been applied.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 9d62491214018..e460797a309c4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1568,7 +1568,9 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
 /// permutations.
 static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
                                               ArrayRef<int64_t> destShape) {
-  return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
+  PackingMetadata metadata;
+  return applyPermutation(destShape,
+                          linalg::getPackInverseDestPerm(packOp, metadata));
 }
 
 /// Determines whether a mask for xfer_write is trivially "all true"
@@ -1761,99 +1763,6 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
   return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
-/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
-/// padding value and (3) input vector sizes into:
-///
-///   masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
-///
-/// As in the following example:
-/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
-///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
-///
-/// This pack would be vectorized to:
-///
-/// %load = vector.mask %mask {
-///     vector.transfer_read %arg0[%c0, %c0, %c0], %cst
-///         {in_bounds = [true, true, true]} :
-///         tensor<32x7x16xf32>, vector<32x8x16xf32>
-/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
-/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
-///                                         to vector<32x4x2x1x16xf32>
-/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
-///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
-/// %write = vector.transfer_write %transpose,
-///     %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
-///     {in_bounds = [true, true, true, true, true]}
-///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
-///
-/// If the (3) input vector sizes are not provided, the vector sizes are
-/// determined by the result tensor shape and the `in_bounds`
-/// attribute is used instead of masking to mark out-of-bounds accesses.
-///
-/// NOTE: The input vector sizes specify the dimensions corresponding to the
-/// outer dimensions of the output tensor. The remaining dimensions are
-/// computed based on, e.g., the static inner tiles.
-/// Supporting dynamic inner tiles will require the user to specify the
-/// missing vector sizes. This is left as a TODO.
-static LogicalResult
-vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        SmallVectorImpl<Value> &newResults) {
-  // TODO: Introduce a parent class that will handle the insertion point update.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(packOp);
-
-  Location loc = packOp.getLoc();
-  std::optional<Value> padValue = packOp.getPaddingValue()
-                                      ? std::optional(packOp.getPaddingValue())
-                                      : std::nullopt;
-
-  // If the input vector sizes are not provided, then the vector sizes are
-  // determined by the result tensor shape. In case the vector sizes aren't
-  // provided, we update the inBounds attribute instead of masking.
-  bool useInBoundsInsteadOfMasking = false;
-  if (inputVectorSizes.empty()) {
-    ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
-    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
-    useInBoundsInsteadOfMasking = true;
-  }
-
-  // Create masked TransferReadOp.
-  SmallVector<int64_t> inputShape(inputVectorSizes);
-  auto innerTiles = packOp.getStaticInnerTiles();
-  auto innerDimsPos = packOp.getInnerDimsPos();
-  auto outerDimsPerm = packOp.getOuterDimsPerm();
-  if (!outerDimsPerm.empty())
-    applyPermutationToVector(inputShape,
-                             invertPermutationVector(outerDimsPerm));
-  for (auto [idx, size] : enumerate(innerTiles))
-    inputShape[innerDimsPos[idx]] *= size;
-  auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, packOp.getSource(), inputShape, padValue,
-      useInBoundsInsteadOfMasking,
-      /*inputScalableVecSizes=*/{});
-
-  // Create ShapeCastOp.
-  SmallVector<int64_t> destShape(inputVectorSizes);
-  destShape.append(innerTiles.begin(), innerTiles.end());
-  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
-                                       packOp.getDestType().getElementType());
-  auto shapeCastOp =
-      vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);
-
-  // Create TransposeOp.
-  auto destPermutation =
-      invertPermutationVector(getPackInverseDestPerm(packOp));
-  auto transposeOp = vector::TransposeOp::create(
-      rewriter, loc, shapeCastOp.getResult(), destPermutation);
-
-  // Create TransferWriteOp.
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), packOp.getDest());
-  newResults.push_back(write->getResult(0));
-  return success();
-}
-
 /// Given the re-associations, "collapses" the input Vector type
 ///
 /// This is similar to CollapseShapeOp::inferCollapsedType with two notable
@@ -1901,12 +1810,119 @@ static VectorType getCollapsedVecType(VectorType type,
   return VectorType::get(newShape, type.getElementType(), newScalableFlags);
 }
 
+/// Vectorize `linalg.pack` as:
+///   * xfer_read -> shape_cast -> transpose -> xfer_write
+///
+/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
+/// sizes for the xfer_write operation). This is sufficient to infer the other
+/// vector sizes required here.
+///
+/// If the vector sizes are not provided:
+///  * the vector sizes are determined from the destination tensor static shape.
+///  * the inBounds attribute is used instead of masking.
+///
+/// EXAMPLE (no vector sizes):
+/// ```
+///   %pack = tensor.pack %src
+///     inner_dims_pos = [2, 1]
+///     inner_tiles = [16, 2]
+///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+/// ``
+/// is vectorizes as:
+/// ```
+///   %read = vector.transfer_read %src
+///     : tensor<32x7x16xf32>, vector<32x8x16xf32>
+///   %sc = vector.shape_cast %read
+///     : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+///   %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
+///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+///   %write = vector.transfer_write %tr into %dest
+///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+/// ```
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
+                        ArrayRef<int64_t> inputVectorSizes,
+                        SmallVectorImpl<Value> &newResults) {
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() == packOp.getDestRank() &&
+           "Invalid number of input vector sizes!");
+  }
+
+  // TODO: Introduce a parent class that will handle the insertion point update.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
+  Location loc = packOp.getLoc();
+  std::optional<Value> padValue = packOp.getPaddingValue()
+                                      ? std::optional(packOp.getPaddingValue())
+                                      : std::nullopt;
+
+  SmallVector<int64_t> destShape =
+      SmallVector<int64_t>(packOp.getDestType().getShape());
+
+  // This is just a convenience alias to clearly communicate that the input
+  // vector sizes determine the _write_ sizes.
+  ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;
+
+  // In the absence of input-vector-sizes, use the _static_ input tensor shape.
+  // In addition, use the inBounds attribute instead of masking.
+  bool useInBoundsInsteadOfMasking = false;
+  if (writeVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(destShape))
+      return rewriter.notifyMatchFailure(packOp,
+                                         "Unable to infer vector sizes!");
+
+    writeVectorSizes = destShape;
+    useInBoundsInsteadOfMasking = true;
+  }
+
+  // Compute vector type for the _read_ opeartion. The required dims are
+  // determined based on the _write_ vector sizes. This is done in two
+  // steps:
+  //  1) Invert the permutation/transposition that's part of the Pack
+  //  operation.
+  //  2) Collapse the tiled sizes/dims to "return" to the unpacked domain.
+  PackingMetadata packMetadata;
+  auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata);
+
+  SmallVector<int64_t> inputVecSizesPrePerm(writeVectorSizes);
+  applyPermutationToVector(inputVecSizesPrePerm, destInvPermutation);
+
+  VectorType readVecType = getCollapsedVecType(
+      VectorType::get(inputVecSizesPrePerm, packOp.getType().getElementType()),
+      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+          rewriter.getContext(), packMetadata.reassociations)));
+
+  // Create masked TransferReadOp.
+  auto maskedRead = vector::createReadOrMaskedRead(
+      rewriter, loc, packOp.getSource(), readVecType.getShape(), padValue,
+      useInBoundsInsteadOfMasking,
+      /*inputScalableVecSizes=*/{});
+
+  // Create ShapeCastOp.
+  auto expandedVecType =
+      VectorType::get(inputVecSizesPrePerm, packOp.getType().getElementType());
+  auto shapeCastOp =
+      vector::ShapeCastOp::create(rewriter, loc, expandedVecType, maskedRead);
+
+  // Create TransposeOp.
+  auto destPermutation = invertPermutationVector(destInvPermutation);
+  auto transposeOp = vector::TransposeOp::create(
+      rewriter, loc, shapeCastOp.getResult(), destPermutation);
+
+  // Create TransferWriteOp.
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, transposeOp.getResult(), packOp.getDest());
+  newResults.push_back(write->getResult(0));
+  return success();
+}
+
 /// Vectorize `linalg.unpack` as:
 ///   * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
 ///
-/// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
-/// for the xfer_read operation). This is sufficient to infer the other vector
-/// sizes required here.
+/// The input-vector-sizes specify the _read_ vector sizes (i.e. the vector
+/// sizes for the xfer_read operation). This is sufficient to infer the other
+/// vector sizes required here.
 ///
 /// If the vector sizes are not provided:
 ///  * the vector sizes are determined from the input tensor static shape.
@@ -1960,7 +1976,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   // In the absence of input-vector-sizes, use the _static_ input tensor shape.
   if (inputVectorSizes.empty()) {
     if (ShapedType::isDynamicShape(sourceShape))
-      return failure();
+      return rewriter.notifyMatchFailure(unpackOp,
+                                         "Unable to infer vector sizes!");
 
     readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
     useInBoundsInsteadOfMasking = true;
@@ -2443,6 +2460,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
   Attribute cstAttr;
+  // TODO: Relax this condiiton
   if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
     LDBG() << "pad value is not constant: " << packOp;
     return failure();
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 24d3722cf5426..a91397d29f3e3 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -171,9 +171,9 @@ computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
 namespace mlir {
 namespace linalg {
 
-SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp) {
+SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp,
+                                            PackingMetadata &pMetadata) {
 
-  PackingMetadata pMetadata;
   int64_t packedRank = packOp.getDestType().getRank();
   ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
   ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
@@ -189,11 +189,11 @@ SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp) {
 
 SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp,
                                              PackingMetadata &metadata) {
-  int64_t unpackRank = unpackOp.getSourceType().getRank();
+  int64_t packedRank = unpackOp.getSourceType().getRank();
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
   SmallVector<int64_t> unpackInvSrcPerm =
-      computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata);
+      computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
   return unpackInvSrcPerm;
 }
 
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
index 93a03369be239..cd472802dd307 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
@@ -285,6 +285,8 @@ module attributes {transform.with_named_sequence} {
 
 ///----------------------------------------------------------------------------------------
 /// Tests for linalg.pack
+///
+/// TODO: Add similar tests for linalg.unpack
 ///----------------------------------------------------------------------------------------
 
 // Note, see a similar test in:
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 1304a90349f71..6d3544ff4f23d 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -1335,7 +1335,7 @@ func.func @pack_no_padding(%src: tensor<32x8x16xf32>, %dest: tensor<4x1x32x16x2x
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%src: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.pack"]} in %src : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [4, 1, 32] : !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, 1, 32, 16, 2] : !transform.any_op
     transform.yield
   }
 }
@@ -1378,7 +1378,7 @@ func.func @pack_with_padding(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [32, 4, 1] : !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [32, 4, 1, 16, 2] : !transform.any_op
     transform.yield
   }
 }
@@ -1424,8 +1424,15 @@ module attributes {transform.with_named_sequence} {
 // CHECK-LABEL: func @pack_with_dynamic_dims
 // CHECK-SAME:      %[[SRC:.*]]: tensor<?x?xf32>,
 // CHECK-SAME:      %[[DEST:.*]]: tensor<?x?x16x2xf32>
-func.func @pack_with_dynamic_dims(%src: tensor<?x?xf32>, %dest: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
-  %pack = linalg.pack %src inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %dest : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
+func.func @pack_with_dynamic_dims(
+    %src: tensor<?x?xf32>, 
+    %dest: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
+
+  %pack = linalg.pack %src 
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 2]
+    into %dest : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
+
   return %pack : tensor<?x?x16x2xf32>
 }
 
@@ -1433,30 +1440,111 @@ func.func @pack_with_dynamic_dims(%src: tensor<?x?xf32>, %dest: tensor<?x?x16x2x
 //  CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
 //  CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
 //  CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
+
+/// Compute mask for xfer_read
 //  CHECK-DAG: %[[D0_0:.*]] = tensor.dim {{.*}} %[[C0_0]] : tensor<?x?xf32>
 //  CHECK-DAG: %[[D1_0:.*]] = tensor.dim {{.*}} %[[C1_0]] : tensor<?x?xf32>
 //      CHECK: %[[MASK:.*]] = vector.create_mask %[[D0_0]], %[[D1_0]] : vector<8x16xi1>
+
+/// --= read =---
 //      CHECK: %[[READ:.*]] = vector.mask %[[MASK]] {
 // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[C0_1]], %[[C0_1]]], %[[CST]]
 // CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
 // CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
+
+/// --= shape_cast =---
 //      CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<8x16xf32> to vector<4x2x1x16xf32>
+
+/// --= transpose =---
 //      CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
+
+/// Compute mask for xfer_write
 //  CHECK-DAG: %[[C0_2:.*]] = arith.constant 0 : index
 //  CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
 //  CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 //  CHECK-DAG: %[[D2:.*]] = tensor.dim %[[DEST]], {{.*}} : tensor<?x?x16x2xf32>
 //  CHECK-DAG: %[[D3:.*]] = tensor.dim %[[DEST]], {{.*}} : tensor<?x?x16x2xf32>
 //      CHECK: %[[MASK_0:.*]] = vector.create_mask %[[D2]], %[[D3]], %[[C16]], %[[C2]] : vector<4x1x16x2xi1>
+
+/// --= write =---
 //      CHECK: %[[WRITE:.*]] = vector.mask %[[MASK_0]] {
 // CHECK-SAME:   vector.transfer_write %[[TR]], %[[DEST]][%[[C0_2]], %[[C0_2]], %[[C0_2]], %[[C0_2]]]
 // CHECK-SAME:   {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<?x?x16x2xf32>
+
 //      CHECK: return %[[WRITE]] : tensor<?x?x16x2xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, 1, 16, 2] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+/// Similar to the test above, but one of the inner tile sizes is dynamic. As a
+/// result, more output dims are dynamic (and, e.g., output mask calcuation is a bit different).
+
+// CHECK-LABEL: func @pack_with_dynamic_dims_and_dynamic_inner_tile
+// CHECK-SAME:      %[[SRC:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:     ...
[truncated]

@banach-space banach-space changed the title users/banach space/linalg/vectorize pack [mlir][linalg] Update vectorizatio of linalg.pack Oct 15, 2025
@rengolin rengolin requested a review from adam-smnk October 15, 2025 10:19
@banach-space banach-space force-pushed the users/banach-space/linalg/vectorize_pack branch from 3f2aacb to e2ec90e Compare October 15, 2025 12:52
@hanhanW hanhanW changed the title [mlir][linalg] Update vectorizatio of linalg.pack [mlir][linalg] Update vectorization of linalg.pack Oct 15, 2025
Base automatically changed from users/banach-space/linalg/update_vec_tests to main October 16, 2025 08:25
@banach-space banach-space force-pushed the users/banach-space/linalg/vectorize_pack branch 2 times, most recently from aa3cea2 to a26b730 Compare October 16, 2025 09:48
@github-actions
Copy link

github-actions bot commented Oct 16, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

This patch changes `vectorizeAsTensorPackOp` to require users to specify
all write-side vector sizes for `linalg.pack` (not just the outer
dimensions). This makes `linalg.pack` vectorization consistent with
`linalg.unpack` (see #149293 for a similar change).

Conceptually, `linalg.pack` consists of these high-level steps:
  * **Read** from the source tensor using `vector.transfer_read`.
  * **Re-associate** dimensions of the transposed value, as specified by
    the op (via `vector.shape_cast`)
  * **Transpose** the re-associated value according to the permutation
    in the `linalg.pack` op (via `vector.transpose`).
  * **Write** the result into the destination tensor via
    `vector.transfer_write`.

Previously, the vector sizes provided by the user were interpreted as
write-vector-sizes for PackOp _outer_ dims (i.e. the final step above).
These were used to:
  * Infer read-vector-sizes using the `inner_tiles` attribute of PackOp.
  * Deduce vector sizes for the transpose and shape cast operations.
  * Ultimately determine the vector shape for the read.

However, this logic breaks when one or more tile sizes are dynamic (*).
In such cases, `vectorizePackOpPrecondition` would currently fail (see
`@pack_with_dynamic_dims_and_dynamic_inner_tile` added in this PR -
without this change it will crash).

This patch updates the contract: users now directly specify _all_ the
"write-vector-sizes", which inherently encode all inner tile sizes - including
dynamic ones. It becomes the user's responsibility to provide valid sizes.

In practice, since `linalg.pack` is typically constructed, tiled, and
vectorized by the same transformation pipeline, the necessary
"write-vector-sizes" should be recoverable.

Notes for reviewers:
  * See test updates for user-facing impact.
  * Review `vectorizeAsTensorPackOp` as a new implementation rather than
    a diff.
  * Comments and variable names were updated to align with
    `vectorizeAsTensorUnPackOp`.

(*) As a concrete example, "scalable" tile sizes are represent as
dynamic values. Note, support for "scalable" vectorisation will be added
in a separate PR.
@banach-space banach-space force-pushed the users/banach-space/linalg/vectorize_pack branch from a26b730 to bea41f5 Compare October 16, 2025 10:28
// before any outer or inner permutations have been applied.
PackingMetadata packingMetadata = computePackingMetadata(
packedTensorType.getRank(), packOp.getInnerDimsPos());
PackingMetadata packMetadata;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll take a better look at this this week, but just as an initial comment, what exactly does the packMetadata do over here? Is it supposed to replace the packingMetadata?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, I reverted this change.

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I'll prepare a fix for downstream IREE project.

/// that packOp uses destination rank whereas unpack Uses source rank.
SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp);
SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp,
PackingMetadata &metadatap);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pMetadata? It does not match the implementation; I think packingMetadata looks better, as you are exposing it as a function argument. Or it can just be metadata like the other function, i.e. getUnPackInverseSrcPerm. The doc needs to be updated as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, that's a typo. Let me update this to match getUnPackInverseSrcPerm. I will also update the docs for both hooks - in fact, I will make them much shorter. Right now, IMHO, they are too long and go into implementation details that should be left for the implementation itself:

/// Shell function to compute the Destination Permutation of PackOp
/// This function uses the helper function `computePackUnPackPerm` to get
/// the permutation vector. Only major difference between UnPack and Pack is
/// that packOp uses destination rank whereas unpack Uses source rank.

I will also remove this helper hook which doesn't seem to be required (at least based on "upstream"):

SmallVector<int64_t> getUnPackInverseSrcPerm(linalg::UnPackOp unpackOp);

Comment on lines 1492 to 1504
func.func @pack_with_dynamic_dims_and_dynamic_inner_tile(
%src: tensor<?x?xf32>,
%dest: tensor<?x?x?x2xf32>) -> tensor<?x?x?x2xf32> {

%c16 = arith.constant 16 : index

%pack = linalg.pack %src
inner_dims_pos = [1, 0]
inner_tiles = [%c16, 2]
into %dest : tensor<?x?xf32> -> tensor<?x?x?x2xf32>

return %pack : tensor<?x?x?x2xf32>
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optional: I'd drop blank lines to make it fit window better.

if (writeVectorSizes.empty()) {
if (ShapedType::isDynamicShape(destShape))
return rewriter.notifyMatchFailure(packOp,
"Unable to infer vector sizes!");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We usually start the first sentence with a lowercase letter, and finish the last sentence without a period/exclamation mark .

https://llvm.org/docs/CodingStandards.html#error-and-warning-messages

Copy link
Contributor Author

@banach-space banach-space Oct 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the reminder!

@hanhanW
Copy link
Contributor

hanhanW commented Oct 23, 2025

Testing IREE in iree-org/iree#22400

Note, I deleted `getUnPackInverseSrcPerm(UnPackOp unpackOp)` - I
couldn't find any uses of that hook.
@banach-space banach-space requested a review from hanhanW October 25, 2025 19:46
/// This function uses the helper function `computePackUnPackPerm` to get
/// the permutation vector. Only major difference between UnPack and Pack is
/// that packOp uses destination rank whereas unpack Uses source rank.
/// Compute inverse permutation for the destination tensor (i.e. in the packed
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

picky nit: I'd add a blank line before this comment. It looks easier to me; it is a new chunk of the declaration.

Copy link
Contributor

@egebeysel egebeysel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks and sorry for the delayed review!

SmallVector<int64_t> writeVecSizesUnpermuted(writeVectorSizes);
applyPermutationToVector(writeVecSizesUnpermuted, destInvPermutation);

VectorType readVecType = getCollapsedVecType(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is a little hard to follow at first glance IMO

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants