
Conversation

@banach-space
Contributor

@banach-space banach-space commented Sep 10, 2025

Adds an end-to-end test for computing matrix multiplication using
linalg.mmt4d, combined with "scalable" tiling and "scalable"
vectorisation. This is similar to an existing example that does not use
"scalable" sizes:

  • test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir
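For a sense of what "scalable" data tiling means here, below is a minimal sketch adapted from `@pack_rhs` in the test added by this PR: the inner tile along the N dimension is `8 * vscale`, so the packed tensor's shape is only known at runtime (function name and shapes come from that test):

```mlir
// Adapted from @pack_rhs in the new test: pack B with a scalable inner tile
// (8 * vscale) along N. The packed shape is dynamic because it depends on the
// runtime vector length.
func.func @pack_rhs(%B: tensor<16x13xi32>) -> tensor<?x16x?x1xi32> {
  %pad = arith.constant 0 : i32

  // Inner tile size along N: 8 * vscale.
  %vs = vector.vscale
  %c8 = arith.constant 8 : index
  %tile_n = arith.muli %vs, %c8 : index
  // Number of outer tiles needed to cover N = 13.
  %c13 = arith.constant 13 : index
  %outer_n = arith.ceildivui %c13, %tile_n : index

  %B_pack_empty = tensor.empty(%outer_n, %tile_n) : tensor<?x16x?x1xi32>
  %B_pack = linalg.pack %B
     padding_value(%pad : i32)
     outer_dims_perm = [1, 0]
     inner_dims_pos = [1, 0]
     inner_tiles = [%tile_n, 1]
     into %B_pack_empty : tensor<16x13xi32> -> tensor<?x16x?x1xi32>

  return %B_pack : tensor<?x16x?x1xi32>
}
```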

@llvmbot
Member

llvmbot commented Sep 10, 2025

@llvm/pr-subscribers-mlir-sve

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][vector] Add a new TD op to wrap unit-dim collapsing patterns
  • Add missing LIT excludes
  • Remove TestVectorTransferCollapseInnerMostContiguousDims
  • [mlir][test] Add e2e test for linalg.mmt4d + SVE

Patch is 25.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/157815.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td (+14)
  • (modified) mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (+5)
  • (added) mlir/test/Dialect/Vector/lit.local.cfg (+2)
  • (added) mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir (+11)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir (+3-1)
  • (added) mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-unpack-mmt4d.mlir (+398)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (-32)
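Among the added files, mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir is the TD library that drives the new op; a minimal sketch of the intended usage (mirroring that file) is:

```mlir
// TD sequence that applies the new pattern-wrapping op to every func.func.
module @transforms attributes { transform.with_named_sequence } {
  transform.named_sequence @drop_unit_dims(%module: !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %module
      : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops
    } : !transform.op<"func.func">
    transform.yield
  }
}
```

Tests then preload this library and run the interpreter with `-transform-interpreter=entry-point=drop_unit_dims`, replacing the removed `-test-vector-transfer-collapse-inner-most-dims` test pass.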
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 07a4117a37b2c..85d0b2a28c65b 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -85,6 +85,20 @@ def ApplyDropUnitDimWithShapeCastPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyDropInnerMostUnitDimsFromXferOpsPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Apply vector patterns to drop the innermost unit dims from
+    vector.transfer_read and vector.transfer_write Ops by taking a subview (via
+    memref.subview) of the original source/destination MemRef. Since it
+    requires the inputs/outputs to be MemRefs, this Op is only useful
+    post-bufferization.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyTransferPermutationPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.transfer_permutation_patterns",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index fe066dc04ad55..1bad9221df915 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -88,6 +88,11 @@ void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns(
   vector::populateDropUnitDimWithShapeCastPatterns(patterns);
 }
 
+void transform::ApplyDropInnerMostUnitDimsFromXferOpsPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  vector::populateDropInnerMostUnitDimsXferOpPatterns(patterns);
+}
+
 void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::populateVectorBitCastLoweringPatterns(patterns);
diff --git a/mlir/test/Dialect/Vector/lit.local.cfg b/mlir/test/Dialect/Vector/lit.local.cfg
new file mode 100644
index 0000000000000..62743008a3e3a
--- /dev/null
+++ b/mlir/test/Dialect/Vector/lit.local.cfg
@@ -0,0 +1,2 @@
+# Skip the directory with input TD sequences
+config.excludes = ["td"]
diff --git a/mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir b/mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir
new file mode 100644
index 0000000000000..5bffa20842b0c
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir
@@ -0,0 +1,11 @@
+module @transforms attributes { transform.with_named_sequence } {
+  transform.named_sequence @drop_unit_dims(%module: !transform.any_op {transform.readonly}) {
+
+    %func_op = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops
+    } : !transform.op<"func.func">
+
+    transform.yield
+   }
+}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index cd56c1bf9695b..18c28799a62e5 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt %s -test-vector-transfer-collapse-inner-most-dims -split-input-file | FileCheck %s
+// RUN: mlir-opt -split-input-file \
+// RUN: -transform-preload-library='transform-library-paths=%p/td/xfer-drop-unit-dims.mlir' \
+// RUN: -transform-interpreter=entry-point=drop_unit_dims %s | FileCheck %s
 
 //-----------------------------------------------------------------------------
 // 1. vector.transfer_read
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-unpack-mmt4d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-unpack-mmt4d.mlir
new file mode 100644
index 0000000000000..d001353ef1d7e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-unpack-mmt4d.mlir
@@ -0,0 +1,398 @@
+// DEFINE: %{compile} =  mlir-opt %s \
+// DEFINE:    -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE:    -cse -canonicalize -test-lower-to-llvm
+// DEFINE: %{entry_point} = main
+// DEFINE: %{run} = mlir-runner -e %{entry_point} -entry-point-result=void \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+/// HIGH-LEVEL OVERVIEW
+///
+/// End-to-end test for computing matrix-multiplication using linalg.mmt4d. In
+/// particular, demonstrates how the following MLIR sequence (implemented in
+/// @matmul_via_mmt4d):
+///
+///   A_pack = linalg.pack A
+///   B_pack = linalg.pack B
+///   C_pack = linalg.pack C
+///   out_pack = linalg.mmt4d(A_pack, B_pack, C_pack)
+///
+/// is equivalent to:
+///
+///  linalg.matmul(A, B, C)
+///
+/// (implemented in @matmul).
+///
+/// NOTES ON IMPLEMENTATION
+/// 1. The MMT4D example uses _scalable_ tile sizes for data tiling.
+///   * The matrix-multiplication dimension that's scalable: N.
+///
+/// 2. The lowering of linalg.mmt4d leverages scalable vectorisation.
+///   * The matrix-multiplication dimension that's scalable: N (to match data
+///     tiling configuration).
+///
+/// 3. Neither `linalg.pack` nor `linalg.unpack` are vectorised ATM.
+///
+/// 4. The MMT4D and Pack/Unpack Ops are kept in separate functions to isolate
+///    the corresponding lowering and lowering configs.
+///   * TODO: Ideally, we should consider fusion opportunities by moving these
+///     Ops into one function.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// @main
+//
+// The main entry point that computes matrix multiplication via linalg.mmt4d
+// and linalg.matmul. Note, the output should be independent of the underlying
+// Linalg Op used, as well as SVE vector length.
+//===----------------------------------------------------------------------===//
+func.func @main() {
+  // Allocate and initialise the inputs
+  %A_empty = tensor.empty() : tensor<7x16xi32>
+  %B_empty = tensor.empty() : tensor<16x13xi32>
+
+  %c3 = arith.constant 3 : i32
+  %c4 = arith.constant 4 : i32
+  %A = linalg.fill ins(%c3 : i32) outs(%A_empty : tensor<7x16xi32>) -> tensor<7x16xi32>
+  %B = linalg.fill ins(%c4 : i32) outs(%B_empty : tensor<16x13xi32>) -> tensor<16x13xi32>
+  %C = arith.constant dense<[
+    [ 1,  8, 15, 22, 29, 36, 43, 50, 57, 64, 71, 78, 85],
+    [ 2,  9, 16, 23, 30, 37, 44, 51, 58, 65, 72, 79, 86],
+    [ 3, 10, 17, 24, 31, 38, 45, 52, 59, 66, 73, 80, 87],
+    [ 4, 11, 18, 25, 32, 39, 46, 53, 60, 67, 74, 81, 88],
+    [ 5, 12, 19, 26, 33, 40, 47, 54, 61, 68, 75, 82, 89],
+    [ 6, 13, 20, 27, 34, 41, 48, 55, 62, 69, 76, 83, 90],
+    [ 7, 14, 21, 28, 35, 42, 49, 56, 63, 70, 77, 84, 91]
+  ]> : tensor<7x13xi32>
+
+  // VARIANT: Matrix multiplication via linalg.mmt4d
+  // CHECK: Unranked Memref
+  // CHECK:  [193,   200,   207,   214,   221,   228,   235,   242,   249,   256,   263,   270,   277]
+  // CHECK:  [194,   201,   208,   215,   222,   229,   236,   243,   250,   257,   264,   271,   278]
+  // CHECK:  [195,   202,   209,   216,   223,   230,   237,   244,   251,   258,   265,   272,   279]
+  // CHECK:  [196,   203,   210,   217,   224,   231,   238,   245,   252,   259,   266,   273,   280]
+  // CHECK:  [197,   204,   211,   218,   225,   232,   239,   246,   253,   260,   267,   274,   281]
+  // CHECK:  [198,   205,   212,   219,   226,   233,   240,   247,   254,   261,   268,   275,   282]
+  // CHECK:  [199,   206,   213,   220,   227,   234,   241,   248,   255,   262,   269,   276,   283]
+  %C_mmt4d = func.call @matmul_via_mmt4d(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
+  %C_mmt4d_cast = tensor.cast %C_mmt4d : tensor<7x13xi32> to tensor<*xi32>
+  vector.print str "--------------------------\n"
+  vector.print str "RESULT FROM linalg.mmt4d:\n"
+  vector.print str "--------------------------\n"
+  call @printMemrefI32(%C_mmt4d_cast) : (tensor<*xi32>) -> ()
+
+  // VARIANT: Matrix multiplication via linalg.matmul
+  // CHECK: Unranked Memref
+  // CHECK:  [193,   200,   207,   214,   221,   228,   235,   242,   249,   256,   263,   270,   277]
+  // CHECK:  [194,   201,   208,   215,   222,   229,   236,   243,   250,   257,   264,   271,   278]
+  // CHECK:  [195,   202,   209,   216,   223,   230,   237,   244,   251,   258,   265,   272,   279]
+  // CHECK:  [196,   203,   210,   217,   224,   231,   238,   245,   252,   259,   266,   273,   280]
+  // CHECK:  [197,   204,   211,   218,   225,   232,   239,   246,   253,   260,   267,   274,   281]
+  // CHECK:  [198,   205,   212,   219,   226,   233,   240,   247,   254,   261,   268,   275,   282]
+  // CHECK:  [199,   206,   213,   220,   227,   234,   241,   248,   255,   262,   269,   276,   283]
+  %C_matmul = func.call @matmul(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
+  %C_matmul_cast = tensor.cast %C_matmul : tensor<7x13xi32> to tensor<*xi32>
+  vector.print str "\n--------------------------\n"
+  vector.print str "RESULT FROM linalg.matmul:\n"
+  vector.print str "--------------------------\n"
+  call @printMemrefI32(%C_matmul_cast) : (tensor<*xi32>) -> ()
+
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// @matmul
+//
+// Implements matrix-multiplication via linalg.matmul
+//===----------------------------------------------------------------------===//
+func.func private @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
+  %C_matmul = linalg.matmul ins(%A, %B: tensor<7x16xi32>, tensor<16x13xi32>)
+                            outs(%C: tensor<7x13xi32>) -> tensor<7x13xi32>
+
+  return %C_matmul : tensor<7x13xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// @pack_lhs
+//
+// Implements packing for the A matrix (LHS) in matrix multiplication.
+//===----------------------------------------------------------------------===//
+func.func private @pack_lhs(%A: tensor<7x16xi32>) -> tensor<1x16x8x1xi32> {
+  %pad = arith.constant 0 : i32
+
+  %A_pack_empty = tensor.empty() : tensor<1x16x8x1xi32>
+  %A_pack = linalg.pack %A
+    padding_value(%pad : i32)
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, 1]
+    into %A_pack_empty : tensor<7x16xi32> -> tensor<1x16x8x1xi32>
+
+  return %A_pack : tensor<1x16x8x1xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// @pack_rhs
+//
+// Implements packing for the B matrix (RHS) in matrix multiplication. The
+// inner tile size is "scalable": 8 * vscale.
+//===----------------------------------------------------------------------===//
+func.func private @pack_rhs(%B: tensor<16x13xi32>) ->  tensor<?x16x?x1xi32> {
+  %pad = arith.constant 0 : i32
+
+  // Compute the outer tile size.
+  %vs = vector.vscale
+  %c8 = arith.constant 8 : index
+  %vs_c8 = arith.muli %vs, %c8 : index
+  %c13 = arith.constant 13 : index
+  %outer_tile_size = arith.ceildivui %c13, %vs_c8 : index
+
+  %B_pack_empty = tensor.empty(%outer_tile_size, %vs_c8) : tensor<?x16x?x1xi32>
+  %B_pack = linalg.pack %B
+     padding_value(%pad : i32)
+     outer_dims_perm = [1, 0]
+     inner_dims_pos = [1, 0]
+     inner_tiles = [%vs_c8, 1]
+     into %B_pack_empty : tensor<16x13xi32> -> tensor<?x16x?x1xi32>
+
+  return %B_pack : tensor<?x16x?x1xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// @pack_acc
+//
+// Implements packing for the C matrix (accumulator) in matrix multiplication.
+// The inner tile size is "scalable": 8 * vscale
+//===----------------------------------------------------------------------===//
+func.func private @pack_acc(%C: tensor<7x13xi32>) -> tensor<1x?x8x?xi32> {
+  %pad = arith.constant 0 : i32
+
+  // Compute the outer tile size.
+  %c13 = arith.constant 13 : index
+  %vs = vector.vscale
+  %c8 = arith.constant 8 : index
+  %vs_c8 = arith.muli %vs, %c8 : index
+  %outer_tile_size = arith.ceildivui %c13, %vs_c8 : index
+
+  %C_pack_empty = tensor.empty(%outer_tile_size, %vs_c8) : tensor<1x?x8x?xi32>
+  %C_pack = linalg.pack %C
+    padding_value(%pad : i32)
+    outer_dims_perm = [0, 1]
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, %vs_c8] into %C_pack_empty : tensor<7x13xi32> -> tensor<1x?x8x?xi32>
+
+  return %C_pack : tensor<1x?x8x?xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// @unpack_acc
+//
+// Implements unpacking for the C matrix (accumulator) in matrix
+// multiplication. The inner tile size is "scalable": 8 * vscale
+//===----------------------------------------------------------------------===//
+func.func private @unpack_acc(%C_packed: tensor<1x?x8x?xi32>) -> tensor<7x13xi32> {
+  %vs = vector.vscale
+  %c8 = arith.constant 8 : index
+  %vs_c8 = arith.muli %vs, %c8 : index
+
+  %C_out_empty = tensor.empty() : tensor<7x13xi32>
+  %C_out_unpack = linalg.unpack %C_packed
+    outer_dims_perm = [0, 1]
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, %vs_c8]
+    into %C_out_empty : tensor<1x?x8x?xi32> -> tensor<7x13xi32>
+
+  return %C_out_unpack: tensor<7x13xi32>
+}
+
+//===----------------------------------------------------------------------===//
+//  Helper methods for printing
+//===----------------------------------------------------------------------===//
+func.func private @print_pack_A(%A_pack : tensor<1x16x8x1xi32>) -> () {
+  %A_pack_cast = tensor.cast %A_pack : tensor<1x16x8x1xi32> to tensor<*xi32>
+  call @printMemrefI32(%A_pack_cast) : (tensor<*xi32>) -> ()
+
+  return
+}
+
+func.func private @print_pack_B(%B_pack : tensor<?x16x?x1xi32>) -> () {
+  %B_pack_cast = tensor.cast %B_pack : tensor<?x16x?x1xi32> to tensor<*xi32>
+  call @printMemrefI32(%B_pack_cast) : (tensor<*xi32>) -> ()
+
+  return
+}
+
+func.func private @print_pack_C(%C_pack : tensor<1x?x8x?xi32>) -> () {
+  %C_pack_cast = tensor.cast %C_pack : tensor<1x?x8x?xi32> to tensor<*xi32>
+  call @printMemrefI32(%C_pack_cast) : (tensor<*xi32>) -> ()
+
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// @matmul_via_mmt4d
+//
+// Implements matrix-multiplication via linalg.mmt4d
+//===----------------------------------------------------------------------===//
+func.func private @matmul_via_mmt4d(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
+  // Pack input matrices
+  %A_pack = func.call @pack_lhs(%A): (tensor<7x16xi32>) -> tensor<1x16x8x1xi32>
+  %B_pack = func.call @pack_rhs(%B): (tensor<16x13xi32>) -> tensor<?x16x?x1xi32>
+  %C_pack = func.call @pack_acc(%C): (tensor<7x13xi32>) -> tensor<1x?x8x?xi32>
+
+  // Print the packed matrices (this is the only _visible_ part that changes
+  // when adjusting the SVE vector size).
+  func.call @print_pack_A(%A_pack) : (tensor<1x16x8x1xi32>) -> ()
+  func.call @print_pack_B(%B_pack) : (tensor<?x16x?x1xi32>) -> ()
+  func.call @print_pack_C(%C_pack) : (tensor<1x?x8x?xi32>) -> ()
+
+  // MMT4D
+  %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<1x16x8x1xi32>, tensor<?x16x?x1xi32>) outs(%C_pack : tensor<1x?x8x?xi32>) -> tensor<1x?x8x?xi32>
+
+  // Unpack the output
+  %C_out_unpack = func.call @unpack_acc(%mmt4d) : (tensor<1x?x8x?xi32>) -> tensor<7x13xi32>
+
+  return %C_out_unpack : tensor<7x13xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// TD Sequence
+//===----------------------------------------------------------------------===//
+module @transforms attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.consumed}) {
+    //==========================================================================
+    // HANDLE MMT4D
+    //==========================================================================
+    %mmt4d = transform.collect_matching @match_mmt4d in %module : (!transform.any_op) -> (!transform.any_op)
+    %mmt4d_func = transform.get_parent_op %mmt4d {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
+
+    // Step 1: Tile
+    // Tile parallel dims (note, the N dim is scalable!)
+    %tiled_mmt4d_parallel, %_:4 = transform.structured.tile_using_for %mmt4d tile_sizes [1, 1, 0, 8, [8], 0]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    // Tile reduction dims
+    %tiled_mmt4d, %_1:2 = transform.structured.tile_using_for %tiled_mmt4d_parallel tile_sizes [0, 0, 1, 0, 0, 1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+    // Step 2: Vectorize linalg.mmt4d (note, the N dim is scalable!)
+    transform.structured.vectorize %tiled_mmt4d
+      vector_sizes  [1, 1, 1, 8, [8], 1]  {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
+
+    // Step 3: Simplify
+    // vector.multi_reduction --> vector.contract
+    // Generates a 6-dim vector.contract with the dim matching the original MMT4D Op
+    // and with the following split into parallel and reduction dims:
+    //    * parallel, parallel, reduction, parallel, parallel, reduction
+    transform.apply_patterns to %mmt4d_func {
+      transform.apply_patterns.vector.reduction_to_contract
+      // Reduce the rank of xfer ops. This transforms vector.contract to be
+      // more matmul-like and to enable the lowering to outer product Ops.
+      transform.apply_patterns.vector.transfer_permutation_patterns
+    } : !transform.op<"func.func">
+
+   // Hoisting and LICM - not strictly required
+   %mmt4d_func_h = transform.structured.hoist_redundant_vector_transfers %mmt4d_func
+     : (!transform.op<"func.func">) -> !transform.op<"func.func">
+   %all_loops = transform.structured.match interface{LoopLikeInterface} in %mmt4d_func_h
+     : (!transform.op<"func.func">) -> !transform.any_op
+   transform.apply_licm to %all_loops : !transform.any_op
+   transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
+
+   // Simplification
+   transform.apply_patterns to %mmt4d_func_h {
+     transform.apply_patterns.vector.reduction_to_contract
+     transform.apply_patterns.vector.cast_away_vector_leading_one_dim
+     transform.apply_patterns.canonicalization
+   } : !transform.op<"func.func">
+
+   //==========================================================================
+   // HANDLE PACK + UNPACK
+   //==========================================================================
+    %pack = transform.structured.match ops{["linalg.pack"]} in %module : (!transform.any_op) -> !transform.any_op
+    %unpack = transform.structured.match ops{["linalg.unpack"]} in %module : (!transform.any_op) -> !transform.any_op
+
+    // 1.1 Tile the linalg.pack Op so that we can decompose it into e.g. tensor.pad
+    //    and other lower-level Ops (see step 2.1)
+    %tiled_pack_op_p, %loops_pack:2 = transform.structured.tile_using_for %pack tile_sizes [1, 1]
+       : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+    // 1.2 Tile the linalg.unpack Op so that we can decompose it into e.g. tensor.pad
+    //    and other lower-level Ops (see step 2)
+    %tiled_unpack_op_p, %loops_unpack:2 = transform.structured.tile_using_for %unpack tile_sizes [8, 1]
+       : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+    // 2.1. Decompose tiled PackOp into lower-level Ops + simplify
+    %func_op_pack = transform.get_parent_op %tiled_pack_op_p {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op_pack {
+      transform.apply_patterns.linalg.decompose_pack_unpack
+      transform.apply_patterns.linalg.decompose_pad
+    } : !transform.op<"func.func">
+
+    transform.apply_patterns to %func_op_pack {
+      transform.apply_patterns.tensor.fold_tensor_subset_ops
+  ...
[truncated]
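As a quick sanity check of the expected values in the CHECK lines: A is filled with 3, B with 4, and the reduction dimension K is 16, so every output entry is just the initial C value shifted by a constant:

$$C_{out}[i][j] = C[i][j] + \sum_{k=0}^{15} A[i][k] \cdot B[k][j] = C[i][j] + 16 \cdot 3 \cdot 4 = C[i][j] + 192$$

e.g. the top-left entry 1 becomes 193, matching the first CHECK row of both variants.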

@llvmbot
Member

llvmbot commented Sep 10, 2025

@llvm/pr-subscribers-mlir-linalg


@llvmbot
Member

llvmbot commented Sep 10, 2025

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][vector] Add a new TD op to wrap unit-dim collapsing patterns
  • Add missing LIT excludes
  • Remove TestVectorTransferCollapseInnerMostContiguousDims
  • [mlir][test] Add e2e test for linalg.mmt4d + SVE

Patch is 25.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/157815.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td (+14)
  • (modified) mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (+5)
  • (added) mlir/test/Dialect/Vector/lit.local.cfg (+2)
  • (added) mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir (+11)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir (+3-1)
  • (added) mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-unpack-mmt4d.mlir (+398)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (-32)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 07a4117a37b2c..85d0b2a28c65b 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -85,6 +85,20 @@ def ApplyDropUnitDimWithShapeCastPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyDropInnerMostUnitDimsFromXferOpsPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Apply vector patterns to drop the inner most unit dims from
+    vector.transfer_read and vector.transfer_write Ops by taking a subview (via
+    memref.subview) of the original source/destination MemRef. Since it
+    requires the input/ouptu to be MemRefs, this Op is only helpful
+    past-bufferization.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyTransferPermutationPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.transfer_permutation_patterns",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index fe066dc04ad55..1bad9221df915 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -88,6 +88,11 @@ void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns(
   vector::populateDropUnitDimWithShapeCastPatterns(patterns);
 }
 
+void transform::ApplyDropInnerMostUnitDimsFromXferOpsPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  vector::populateDropInnerMostUnitDimsXferOpPatterns(patterns);
+}
+
 void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::populateVectorBitCastLoweringPatterns(patterns);
diff --git a/mlir/test/Dialect/Vector/lit.local.cfg b/mlir/test/Dialect/Vector/lit.local.cfg
new file mode 100644
index 0000000000000..62743008a3e3a
--- /dev/null
+++ b/mlir/test/Dialect/Vector/lit.local.cfg
@@ -0,0 +1,2 @@
+# Skip the directory with input TD sequences
+config.excludes = ["td"]
diff --git a/mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir b/mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir
new file mode 100644
index 0000000000000..5bffa20842b0c
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir
@@ -0,0 +1,11 @@
+module @transforms attributes { transform.with_named_sequence } {
+  transform.named_sequence @drop_unit_dims(%module: !transform.any_op {transform.readonly}) {
+
+    %func_op = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops
+    } : !transform.op<"func.func">
+
+    transform.yield
+   }
+}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index cd56c1bf9695b..18c28799a62e5 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt %s -test-vector-transfer-collapse-inner-most-dims -split-input-file | FileCheck %s
+// RUN: mlir-opt -split-input-file \
+// RUN: -transform-preload-library='transform-library-paths=%p/td/xfer-drop-unit-dims.mlir' \
+// RUN: -transform-interpreter=entry-point=drop_unit_dims %s | FileCheck %s
 
 //-----------------------------------------------------------------------------
 // 1. vector.transfer_read
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-unpack-mmt4d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-unpack-mmt4d.mlir
new file mode 100644
index 0000000000000..d001353ef1d7e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-unpack-mmt4d.mlir
@@ -0,0 +1,398 @@
+// DEFINE: %{compile} =  mlir-opt %s \
+// DEFINE:    -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE:    -cse -canonicalize -test-lower-to-llvm
+// DEFINE: %{entry_point} = main
+// DEFINE: %{run} = mlir-runner -e %{entry_point} -entry-point-result=void \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+/// HIGH-LEVEL OVERVIEW
+///
+/// End-to-end test for computing matrix-multiplication using linalg.mmt4d. In
+/// particular, demonstrates how the following MLIR sequence (implemented in
+/// @matmul_via_mmt4d):
+///
+///   A_pack = linalg.pack A
+///   B_pack = linalg.pack B
+///   C_pack = linalg.pack C
+///   out_pack = linalg.mmt4d(A_pack, B_pack, C_pack)
+///
+/// is equivalent to:
+///
+///  linalg.matmul(A, B, C)
+///
+/// (implemented in @matmul_via_matmul).
+///
+/// NOTES ON IMPLEMENTATION
+/// 1. The MMT4D example uses _scalable_ tile sizes for data tiling.
+///   * The matrix-multiplication dimension that's scalable: N.
+///
+/// 2. The lowering of linalg.mmt4d leverages scalable vectorisation.
+///   * The matrix-multiplication dimension that's scalable: N (to match data
+///     tiling configuration).
+///
+/// 3. Neither `linalg.pack` nor `linalg.unpack` are vectorised ATM.
+///
+/// 4. The MMT4D and Pack/Unpack Ops are kept in seperate functions to isolate
+///    the corresponding lowering and lowering configs.
+///   * TODO: Ideally, we should consider fusion opportunities by moving these
+///     Ops into one function.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// @main
+//
+// The main entry point that computes matrix multiplication via linalg.mmt4d
+// and linalg.matmul. Note, the output should be independent of the underlying
+// Linalg Op used, as well as SVE vector length.
+//===----------------------------------------------------------------------===//
+func.func @main() {
+  // Allocate and initialise the inputs
+  %A_empty = tensor.empty() : tensor<7x16xi32>
+  %B_empty = tensor.empty() : tensor<16x13xi32>
+
+  %c3 = arith.constant 3 : i32
+  %c4 = arith.constant 4 : i32
+  %A = linalg.fill ins(%c3 : i32) outs(%A_empty : tensor<7x16xi32>) -> tensor<7x16xi32>
+  %B = linalg.fill ins(%c4 : i32) outs(%B_empty : tensor<16x13xi32>) -> tensor<16x13xi32>
+  %C = arith.constant dense<[
+    [ 1,  8, 15, 22, 29, 36, 43, 50, 57, 64, 71, 78, 85],
+    [ 2,  9, 16, 23, 30, 37, 44, 51, 58, 65, 72, 79, 86],
+    [ 3, 10, 17, 24, 31, 38, 45, 52, 59, 66, 73, 80, 87],
+    [ 4, 11, 18, 25, 32, 39, 46, 53, 60, 67, 74, 81, 88],
+    [ 5, 12, 19, 26, 33, 40, 47, 54, 61, 68, 75, 82, 89],
+    [ 6, 13, 20, 27, 34, 41, 48, 55, 62, 69, 76, 83, 90],
+    [ 7, 14, 21, 28, 35, 42, 49, 56, 63, 70, 77, 84, 91]
+  ]> : tensor<7x13xi32>
+
+  // VARIANT: Matrix multiplication via linalg.mmt4d
+  // CHECK: Unranked Memref
+  // CHECK:  [193,   200,   207,   214,   221,   228,   235,   242,   249,   256,   263,   270,   277]
+  // CHECK:  [194,   201,   208,   215,   222,   229,   236,   243,   250,   257,   264,   271,   278]
+  // CHECK:  [195,   202,   209,   216,   223,   230,   237,   244,   251,   258,   265,   272,   279]
+  // CHECK:  [196,   203,   210,   217,   224,   231,   238,   245,   252,   259,   266,   273,   280]
+  // CHECK:  [197,   204,   211,   218,   225,   232,   239,   246,   253,   260,   267,   274,   281]
+  // CHECK:  [198,   205,   212,   219,   226,   233,   240,   247,   254,   261,   268,   275,   282]
+  // CHECK:  [199,   206,   213,   220,   227,   234,   241,   248,   255,   262,   269,   276,   283]
+  %C_mmt4d = func.call @matmul_via_mmt4d(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
+  %C_mmt4d_cast = tensor.cast %C_mmt4d : tensor<7x13xi32> to tensor<*xi32>
+  vector.print str "--------------------------\n"
+  vector.print str "RESULT FROM linalg.mmt4d:\n"
+  vector.print str "--------------------------\n"
+  call @printMemrefI32(%C_mmt4d_cast) : (tensor<*xi32>) -> ()
+
+  // VARIANT: Matrix multiplication via linalg.matmul
+  // CHECK: Unranked Memref
+  // CHECK:  [193,   200,   207,   214,   221,   228,   235,   242,   249,   256,   263,   270,   277]
+  // CHECK:  [194,   201,   208,   215,   222,   229,   236,   243,   250,   257,   264,   271,   278]
+  // CHECK:  [195,   202,   209,   216,   223,   230,   237,   244,   251,   258,   265,   272,   279]
+  // CHECK:  [196,   203,   210,   217,   224,   231,   238,   245,   252,   259,   266,   273,   280]
+  // CHECK:  [197,   204,   211,   218,   225,   232,   239,   246,   253,   260,   267,   274,   281]
+  // CHECK:  [198,   205,   212,   219,   226,   233,   240,   247,   254,   261,   268,   275,   282]
+  // CHECK:  [199,   206,   213,   220,   227,   234,   241,   248,   255,   262,   269,   276,   283]
+  %C_matmul = func.call @matmul(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
+  %C_matmul_cast = tensor.cast %C_matmul : tensor<7x13xi32> to tensor<*xi32>
+  vector.print str "\n--------------------------\n"
+  vector.print str "RESULT FROM linalg.matmul:\n"
+  vector.print str "--------------------------\n"
+  call @printMemrefI32(%C_matmul_cast) : (tensor<*xi32>) -> ()
+
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// @matmul_via_matmul
+//
+// Implements matrix-multiplication via linalg.matmul
+//===----------------------------------------------------------------------===//
+func.func private @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
+  %C_matmul = linalg.matmul ins(%A, %B: tensor<7x16xi32>, tensor<16x13xi32>)
+                            outs(%C: tensor<7x13xi32>) -> tensor<7x13xi32>
+
+  return %C_matmul : tensor<7x13xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// @matmul_via_mmt4d
+//
+// Implements matrix-multiplication via linalg.mmt4d
+//===----------------------------------------------------------------------===//
+func.func private @pack_lhs(%A: tensor<7x16xi32>) -> tensor<1x16x8x1xi32> {
+  %pad = arith.constant 0 : i32
+
+  %A_pack_empty = tensor.empty() : tensor<1x16x8x1xi32>
+  %A_pack = linalg.pack %A
+    padding_value(%pad : i32)
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, 1]
+    into %A_pack_empty : tensor<7x16xi32> -> tensor<1x16x8x1xi32>
+
+  return %A_pack : tensor<1x16x8x1xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// @pack_rhs
+//
+// Implements packing for the B matrix (RHS) in matrix multiplication. The
+// inner tile size is "scalable": 8 * vscale.
+//===----------------------------------------------------------------------===//
+func.func private @pack_rhs(%B: tensor<16x13xi32>) ->  tensor<?x16x?x1xi32> {
+  %pad = arith.constant 0 : i32
+
+  // Compute the outer tile size.
+  %vs = vector.vscale
+  %c8 = arith.constant 8 : index
+  %vs_c8 = arith.muli %vs, %c8 : index
+  %c13 = arith.constant 13 : index
+  %outer_tile_size = arith.ceildivui %c13, %vs_c8 : index
+
+  %B_pack_empty = tensor.empty(%outer_tile_size, %vs_c8) : tensor<?x16x?x1xi32>
+  %B_pack = linalg.pack %B
+     padding_value(%pad : i32)
+     outer_dims_perm = [1, 0]
+     inner_dims_pos = [1, 0]
+     inner_tiles = [%vs_c8, 1]
+     into %B_pack_empty : tensor<16x13xi32> -> tensor<?x16x?x1xi32>
+
+  return %B_pack : tensor<?x16x?x1xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// @pack_acc
+//
+// Implements packing for the C matrix (accumulator) in matrix multiplication.
+// The inner tile size is "scalable": 8 * vscale
+//===----------------------------------------------------------------------===//
+func.func private @pack_acc(%C: tensor<7x13xi32>) -> tensor<1x?x8x?xi32> {
+  %pad = arith.constant 0 : i32
+
+  // Compute the outer tile size.
+  %c13 = arith.constant 13 : index
+  %vs = vector.vscale
+  %c8 = arith.constant 8 : index
+  %vs_c8 = arith.muli %vs, %c8 : index
+  %outer_tile_size = arith.ceildivui %c13, %vs_c8 : index
+
+  %C_pack_empty = tensor.empty(%outer_tile_size, %vs_c8) : tensor<1x?x8x?xi32>
+  %C_pack = linalg.pack %C
+    padding_value(%pad : i32)
+    outer_dims_perm = [0, 1]
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, %vs_c8]
+    into %C_pack_empty : tensor<7x13xi32> -> tensor<1x?x8x?xi32>
+
+  return %C_pack : tensor<1x?x8x?xi32>
+}
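+
+// Note (illustrative): %C is 7x13 and inner_tiles = [8, 8 * vscale]. The M
+// dim (7) is zero-padded to 8 (1 outer x 8 inner); the N dim (13) is padded
+// up to the next multiple of 8 * vscale. For vscale = 1 this yields
+// tensor<1x2x8x8xi32>, for vscale = 2 it yields tensor<1x1x8x16xi32>.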
+
+//===----------------------------------------------------------------------===//
+// @unpack_acc
+//
+// Implements unpacking for the C matrix (accumulator) in matrix
+// multiplication. The inner tile size is "scalable": 8 * vscale.
+//===----------------------------------------------------------------------===//
+func.func private @unpack_acc(%C_packed: tensor<1x?x8x?xi32>) -> tensor<7x13xi32> {
+  %vs = vector.vscale
+  %c8 = arith.constant 8 : index
+  %vs_c8 = arith.muli %vs, %c8 : index
+
+  %C_out_empty = tensor.empty() : tensor<7x13xi32>
+  %C_out_unpack = linalg.unpack %C_packed
+    outer_dims_perm = [0, 1]
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, %vs_c8]
+    into %C_out_empty : tensor<1x?x8x?xi32> -> tensor<7x13xi32>
+
+  return %C_out_unpack: tensor<7x13xi32>
+}
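+
+// Note: linalg.unpack reverses @pack_acc - it re-assembles the padded
+// 1x?x8x? accumulator tiles into a 7x13 tensor, dropping the padded rows and
+// columns introduced during packing.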
+
+//===----------------------------------------------------------------------===//
+// Helper functions for printing
+//===----------------------------------------------------------------------===//
+func.func private @print_pack_A(%A_pack : tensor<1x16x8x1xi32>) -> () {
+  %A_pack_cast = tensor.cast %A_pack : tensor<1x16x8x1xi32> to tensor<*xi32>
+  call @printMemrefI32(%A_pack_cast) : (tensor<*xi32>) -> ()
+
+  return
+}
+
+func.func private @print_pack_B(%B_pack : tensor<?x16x?x1xi32>) -> () {
+  %B_pack_cast = tensor.cast %B_pack : tensor<?x16x?x1xi32> to tensor<*xi32>
+  call @printMemrefI32(%B_pack_cast) : (tensor<*xi32>) -> ()
+
+  return
+}
+
+func.func private @print_pack_C(%C_pack : tensor<1x?x8x?xi32>) -> () {
+  %C_pack_cast = tensor.cast %C_pack : tensor<1x?x8x?xi32> to tensor<*xi32>
+  call @printMemrefI32(%C_pack_cast) : (tensor<*xi32>) -> ()
+
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// @matmul_via_mmt4d
+//
+// Implements matrix-multiplication via linalg.mmt4d
+//===----------------------------------------------------------------------===//
+func.func private @matmul_via_mmt4d(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
+  // Pack input matrices
+  %A_pack = func.call @pack_lhs(%A): (tensor<7x16xi32>) -> tensor<1x16x8x1xi32>
+  %B_pack = func.call @pack_rhs(%B): (tensor<16x13xi32>) -> tensor<?x16x?x1xi32>
+  %C_pack = func.call @pack_acc(%C): (tensor<7x13xi32>) -> tensor<1x?x8x?xi32>
+
+  // Print the packed matrices (this is the only _visible_ part that changes
+  // when adjusting the SVE vector size).
+  func.call @print_pack_A(%A_pack) : (tensor<1x16x8x1xi32>) -> ()
+  func.call @print_pack_B(%B_pack) : (tensor<?x16x?x1xi32>) -> ()
+  func.call @print_pack_C(%C_pack) : (tensor<1x?x8x?xi32>) -> ()
+
+  // MMT4D
+  %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<1x16x8x1xi32>, tensor<?x16x?x1xi32>)
+                        outs(%C_pack : tensor<1x?x8x?xi32>) -> tensor<1x?x8x?xi32>
+
+  // Unpack the output
+  %C_out_unpack = func.call @unpack_acc(%mmt4d) : (tensor<1x?x8x?xi32>) -> tensor<7x13xi32>
+
+  return %C_out_unpack : tensor<7x13xi32>
+}
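+
+// Note (illustrative): linalg.mmt4d multiplies the packed operands as
+//   C[m, n, m_in, n_in] += sum over (k, k_in) of
+//     A[m, k, m_in, k_in] * B[n, k, n_in, k_in],
+// i.e. the RHS is indexed as if transposed (the "t" in "mmt"), which is why
+// @pack_rhs above produces the [N_outer, K_outer, N_inner, K_inner] layout
+// (outer_dims_perm = [1, 0], inner_dims_pos = [1, 0]).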
+
+//===----------------------------------------------------------------------===//
+// TD Sequence
+//===----------------------------------------------------------------------===//
+module @transforms attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.consumed}) {
+    //==========================================================================
+    // HANDLE MMT4D
+    //==========================================================================
+    %mmt4d = transform.collect_matching @match_mmt4d in %module : (!transform.any_op) -> (!transform.any_op)
+    %mmt4d_func = transform.get_parent_op %mmt4d {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
+
+    // Step 1: Tile
+    // Tile parallel dims (note, the N dim is scalable!)
+    %tiled_mmt4d_parallel, %_:4 = transform.structured.tile_using_for %mmt4d tile_sizes [1, 1, 0, 8, [8], 0]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    // Tile reduction dims
+    %tiled_mmt4d, %_1:2 = transform.structured.tile_using_for %tiled_mmt4d_parallel tile_sizes [0, 0, 1, 0, 0, 1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
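+
+    // Note (illustrative): the six tile-size entries above follow the mmt4d
+    // iteration order (M_outer, N_outer, K_outer, M_inner, N_inner, K_inner).
+    // [8] marks the inner N tile as scalable, i.e. 8 * vscale at runtime,
+    // matching the inner tile used by @pack_rhs and @pack_acc.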
+
+    // Step 2: Vectorize linalg.mmt4d (note, the N dim is scalable!)
+    transform.structured.vectorize %tiled_mmt4d
+      vector_sizes  [1, 1, 1, 8, [8], 1]  {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
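+
+    // Note (illustrative): the vector sizes mirror the tile sizes from Step 1.
+    // The assume_dynamic_dims_match_vec_sizes attribute lets the vectorizer
+    // assume that the dynamic (scalable) inner N dim matches the requested
+    // [8] (i.e. 8 * vscale) vector size, so no mask is generated for it.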
+
+    // Step 3: Simplify
+    // vector.multi_reduction --> vector.contract
+    // Generates a 6-dim vector.contract with dims matching the original MMT4D Op
+    // and with the following split into parallel and reduction dims:
+    //    * parallel, parallel, reduction, parallel, parallel, reduction
+    transform.apply_patterns to %mmt4d_func {
+      transform.apply_patterns.vector.reduction_to_contract
+      // Reduce the rank of xfer ops. This makes vector.contract more
+      // matmul-like and enables lowering to outer-product Ops.
+      transform.apply_patterns.vector.transfer_permutation_patterns
+    } : !transform.op<"func.func">
+
+    // Hoisting and LICM - not strictly required
+    %mmt4d_func_h = transform.structured.hoist_redundant_vector_transfers %mmt4d_func
+      : (!transform.op<"func.func">) -> !transform.op<"func.func">
+    %all_loops = transform.structured.match interface{LoopLikeInterface} in %mmt4d_func_h
+      : (!transform.op<"func.func">) -> !transform.any_op
+    transform.apply_licm to %all_loops : !transform.any_op
+    transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
+
+    // Simplification
+    transform.apply_patterns to %mmt4d_func_h {
+      transform.apply_patterns.vector.reduction_to_contract
+      transform.apply_patterns.vector.cast_away_vector_leading_one_dim
+      transform.apply_patterns.canonicalization
+    } : !transform.op<"func.func">
+
+    //==========================================================================
+    // HANDLE PACK + UNPACK
+    //==========================================================================
+    %pack = transform.structured.match ops{["linalg.pack"]} in %module : (!transform.any_op) -> !transform.any_op
+    %unpack = transform.structured.match ops{["linalg.unpack"]} in %module : (!transform.any_op) -> !transform.any_op
+
+    // 1.1 Tile the linalg.pack Op so that we can decompose it into e.g. tensor.pad
+    //    and other lower-level Ops (see step 2.1)
+    %tiled_pack_op_p, %loops_pack:2 = transform.structured.tile_using_for %pack tile_sizes [1, 1]
+       : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+    // 1.2 Tile the linalg.unpack Op so that we can decompose it into e.g. tensor.pad
+    //    and other lower-level Ops (see step 2)
+    %tiled_unpack_op_p, %loops_unpack:2 = transform.structured.tile_using_for %unpack tile_sizes [8, 1]
+       : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+    // 2.1. Decompose tiled PackOp into lower-level Ops + simplify
+    %func_op_pack = transform.get_parent_op %tiled_pack_op_p {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op_pack {
+      transform.apply_patterns.linalg.decompose_pack_unpack
+      transform.apply_patterns.linalg.decompose_pad
+    } : !transform.op<"func.func">
+
+    transform.apply_patterns to %func_op_pack {
+      transform.apply_patterns.tensor.fold_tensor_subset_ops
+  ...
[truncated]

@banach-space banach-space changed the title andrzej/vector/add mmt4d with sve e2e [mlir][test] Add e2e test for linalg.mmt4d + SVE Sep 10, 2025
@banach-space banach-space force-pushed the andrzej/vector/add_mmt4d_with_sve_e2e branch from eb702db to 7c26b4c Compare September 12, 2025 10:15
Contributor

@egebeysel egebeysel left a comment
Overall, LGTM!

That's probably something for a separate PR, but can we also add i8mm and bf16 versions of this? I think we should be able to support that at the moment.

///
/// 4. The MMT4D and Pack/Unpack Ops are kept in separate functions to isolate
/// the corresponding lowering and lowering configs.
/// * TODO: Ideally, we should consider fusion opportunities by moving these
Contributor
nit: can we also add the issue numbers for improving the tiling interface for pack + unpack?

Contributor Author
Do you mean some pre-existing issue?

@banach-space banach-space merged commit cca769a into llvm:main Sep 22, 2025
9 checks passed
@banach-space banach-space deleted the andrzej/vector/add_mmt4d_with_sve_e2e branch September 22, 2025 09:26