Skip to content

Conversation

anemet
Copy link
Contributor

@anemet anemet commented Oct 17, 2025

In the inner loop of matmul, instead of continuously halving the HW vector register width, I just use the remainder vector directly if it's legal.

We don't have in-tree targets that have this so I opted for adding a hidden flag to simulate this for testing purposes: -matrix-split-matmul-remainder=0

The tests are the vectorization-friendly 3x3x1 matrix-vector and 1x3x3 vector-matrix multiplies for CM, RM respectively.

In the inner loop of matmul, instead of continuously halving the HW
vector register width, I just use the remainder vector directly if
it's legal.

We don't have in-tree targets that have this so I opted for adding a
hidden flag to simulate this for testing purposes:
-matrix-split-matmul-remainder=0

The tests are the vectorization-friendly 3x3x1 matrix-vector and
1x3x3 vector-matrix multiplies for CM, RM respectively.
@llvmbot
Copy link
Member

llvmbot commented Oct 17, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Adam Nemet (anemet)

Changes

In the inner loop of matmul, instead of continuously halving the HW vector register width, I just use the remainder vector directly if it's legal.

We don't have in-tree targets that have this so I opted for adding a hidden flag to simulate this for testing purposes: -matrix-split-matmul-remainder=0

The tests are the vectorization-friendly 3x3x1 matrix-vector and 1x3x3 vector-matrix multiplies for CM, RM respectively.


Patch is 21.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163987.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp (+29-7)
  • (added) llvm/test/Transforms/LowerMatrixIntrinsics/multiply-remainder-rm.ll (+95)
  • (added) llvm/test/Transforms/LowerMatrixIntrinsics/multiply-remainder.ll (+95)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 7cae94ebb4ba1..242cbee8b3c9d 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -97,6 +97,11 @@ static cl::opt<MatrixLayoutTy> MatrixLayout(
 static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
                                             cl::init(false));
 
+static cl::opt<bool> SplitMatmulRemainder(
+    "matrix-split-matmul-remainder", cl::Hidden,
+    cl::desc("Split remainder vector in the inner loop of matmul"),
+    cl::init(true));
+
 /// Helper function to either return Scope, if it is a subprogram or the
 /// attached subprogram for a local scope.
 static DISubprogram *getSubprogram(DIScope *Scope) {
@@ -1719,6 +1724,26 @@ class LowerMatrixIntrinsics {
     ToRemove.push_back(MatMul);
   }
 
+  /// Given \p Remainder iterations of the the matmul inner loop,
+  /// potentially lower \p Blocksize that is used for the underlying
+  /// vector.
+  unsigned capBlockSize(unsigned BlockSize, unsigned Remainder, Type *EltType) {
+    if (BlockSize <= Remainder)
+      return BlockSize;
+
+    // If the remainder is also a legal type just use it.
+    if (TTI.isTypeLegal(FixedVectorType::get(EltType, Remainder)) ||
+        !SplitMatmulRemainder)
+      return Remainder;
+
+    // Gradually lower the vectorization factor to cover the
+    // remainder.
+    do {
+      BlockSize /= 2;
+    } while (BlockSize > Remainder);
+    return BlockSize;
+  }
+
   /// Compute \p Result += \p A * \p B for input matrices with left-associating
   /// addition.
   ///
@@ -1756,10 +1781,8 @@ class LowerMatrixIntrinsics {
         bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
 
         for (unsigned I = 0; I < R; I += BlockSize) {
-          // Gradually lower the vectorization factor to cover the remainder.
-          while (I + BlockSize > R)
-            BlockSize /= 2;
-
+          // Lower block size to make sure we stay within bounds.
+          BlockSize = capBlockSize(BlockSize, R - I, Result.getElementType());
           Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
                                : nullptr;
           for (unsigned K = 0; K < M; ++K) {
@@ -1784,9 +1807,8 @@ class LowerMatrixIntrinsics {
         unsigned BlockSize = VF;
         bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
         for (unsigned J = 0; J < C; J += BlockSize) {
-          // Gradually lower the vectorization factor to cover the remainder.
-          while (J + BlockSize > C)
-            BlockSize /= 2;
+          // Lower the vectorization factor to cover the remainder.
+          BlockSize = capBlockSize(BlockSize, C - J, Result.getElementType());
 
           Value *Sum = nullptr;
           for (unsigned K = 0; K < M; ++K) {
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-remainder-rm.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-remainder-rm.ll
new file mode 100644
index 0000000000000..f71aa5305f679
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-remainder-rm.ll
@@ -0,0 +1,95 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes='lower-matrix-intrinsics'                                  -matrix-default-layout=row-major -S < %s | FileCheck --check-prefix=RM_SPLIT_REMAINDER %s
+; RUN: opt -passes='lower-matrix-intrinsics' -matrix-split-matmul-remainder=0 -matrix-default-layout=row-major -S < %s | FileCheck --check-prefix=RM_NO_SPLIT_REMAINDER %s
+
+; REQUIRES: aarch64-registered-target
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:8:32:64-S128"
+target triple = "aarch64-apple-ios"
+
+define void @matmul(ptr %a, ptr %b, ptr %c) {
+; RM_SPLIT_REMAINDER-LABEL: define void @matmul(
+; RM_SPLIT_REMAINDER-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]]) {
+; RM_SPLIT_REMAINDER-NEXT:    [[COL_LOAD:%.*]] = load <3 x float>, ptr [[A]], align 4
+; RM_SPLIT_REMAINDER-NEXT:    [[COL_LOAD1:%.*]] = load <3 x float>, ptr [[B]], align 4
+; RM_SPLIT_REMAINDER-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[B]], i64 3
+; RM_SPLIT_REMAINDER-NEXT:    [[COL_LOAD2:%.*]] = load <3 x float>, ptr [[VEC_GEP]], align 4
+; RM_SPLIT_REMAINDER-NEXT:    [[VEC_GEP3:%.*]] = getelementptr float, ptr [[B]], i64 6
+; RM_SPLIT_REMAINDER-NEXT:    [[COL_LOAD4:%.*]] = load <3 x float>, ptr [[VEC_GEP3]], align 4
+; RM_SPLIT_REMAINDER-NEXT:    [[BLOCK:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP1:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 0
+; RM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLATINSERT:%.*]] = insertelement <2 x float> poison, float [[TMP1]], i64 0
+; RM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLAT:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT]], <2 x float> poison, <2 x i32> zeroinitializer
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP2:%.*]] = fmul <2 x float> [[SPLAT_SPLAT]], [[BLOCK]]
+; RM_SPLIT_REMAINDER-NEXT:    [[BLOCK5:%.*]] = shufflevector <3 x float> [[COL_LOAD2]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP3:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 1
+; RM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLATINSERT6:%.*]] = insertelement <2 x float> poison, float [[TMP3]], i64 0
+; RM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLAT7:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT6]], <2 x float> poison, <2 x i32> zeroinitializer
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP4:%.*]] = fmul <2 x float> [[SPLAT_SPLAT7]], [[BLOCK5]]
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP5:%.*]] = fadd <2 x float> [[TMP2]], [[TMP4]]
+; RM_SPLIT_REMAINDER-NEXT:    [[BLOCK8:%.*]] = shufflevector <3 x float> [[COL_LOAD4]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP6:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 2
+; RM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLATINSERT9:%.*]] = insertelement <2 x float> poison, float [[TMP6]], i64 0
+; RM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLAT10:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT9]], <2 x float> poison, <2 x i32> zeroinitializer
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP7:%.*]] = fmul <2 x float> [[SPLAT_SPLAT10]], [[BLOCK8]]
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP8:%.*]] = fadd <2 x float> [[TMP5]], [[TMP7]]
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP9:%.*]] = shufflevector <2 x float> [[TMP8]], <2 x float> poison, <3 x i32> <i32 0, i32 1, i32 poison>
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP10:%.*]] = shufflevector <3 x float> poison, <3 x float> [[TMP9]], <3 x i32> <i32 3, i32 4, i32 2>
+; RM_SPLIT_REMAINDER-NEXT:    [[BLOCK11:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <1 x i32> <i32 2>
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP11:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 0
+; RM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLATINSERT12:%.*]] = insertelement <1 x float> poison, float [[TMP11]], i64 0
+; RM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLAT13:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT12]], <1 x float> poison, <1 x i32> zeroinitializer
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP12:%.*]] = fmul <1 x float> [[SPLAT_SPLAT13]], [[BLOCK11]]
+; RM_SPLIT_REMAINDER-NEXT:    [[BLOCK14:%.*]] = shufflevector <3 x float> [[COL_LOAD2]], <3 x float> poison, <1 x i32> <i32 2>
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP13:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 1
+; RM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLATINSERT15:%.*]] = insertelement <1 x float> poison, float [[TMP13]], i64 0
+; RM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLAT16:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT15]], <1 x float> poison, <1 x i32> zeroinitializer
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP14:%.*]] = fmul <1 x float> [[SPLAT_SPLAT16]], [[BLOCK14]]
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP15:%.*]] = fadd <1 x float> [[TMP12]], [[TMP14]]
+; RM_SPLIT_REMAINDER-NEXT:    [[BLOCK17:%.*]] = shufflevector <3 x float> [[COL_LOAD4]], <3 x float> poison, <1 x i32> <i32 2>
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP16:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 2
+; RM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLATINSERT18:%.*]] = insertelement <1 x float> poison, float [[TMP16]], i64 0
+; RM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLAT19:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT18]], <1 x float> poison, <1 x i32> zeroinitializer
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP17:%.*]] = fmul <1 x float> [[SPLAT_SPLAT19]], [[BLOCK17]]
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP18:%.*]] = fadd <1 x float> [[TMP15]], [[TMP17]]
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP19:%.*]] = shufflevector <1 x float> [[TMP18]], <1 x float> poison, <3 x i32> <i32 0, i32 poison, i32 poison>
+; RM_SPLIT_REMAINDER-NEXT:    [[TMP20:%.*]] = shufflevector <3 x float> [[TMP10]], <3 x float> [[TMP19]], <3 x i32> <i32 0, i32 1, i32 3>
+; RM_SPLIT_REMAINDER-NEXT:    store <3 x float> [[TMP20]], ptr [[C]], align 4
+; RM_SPLIT_REMAINDER-NEXT:    ret void
+;
+; RM_NO_SPLIT_REMAINDER-LABEL: define void @matmul(
+; RM_NO_SPLIT_REMAINDER-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]]) {
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[COL_LOAD:%.*]] = load <3 x float>, ptr [[A]], align 4
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[COL_LOAD1:%.*]] = load <3 x float>, ptr [[B]], align 4
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[B]], i64 3
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[COL_LOAD2:%.*]] = load <3 x float>, ptr [[VEC_GEP]], align 4
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[VEC_GEP3:%.*]] = getelementptr float, ptr [[B]], i64 6
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[COL_LOAD4:%.*]] = load <3 x float>, ptr [[VEC_GEP3]], align 4
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[BLOCK:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[TMP1:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 0
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLATINSERT:%.*]] = insertelement <3 x float> poison, float [[TMP1]], i64 0
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLAT:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT]], <3 x float> poison, <3 x i32> zeroinitializer
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[TMP2:%.*]] = fmul <3 x float> [[SPLAT_SPLAT]], [[BLOCK]]
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[BLOCK5:%.*]] = shufflevector <3 x float> [[COL_LOAD2]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[TMP3:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 1
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLATINSERT6:%.*]] = insertelement <3 x float> poison, float [[TMP3]], i64 0
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLAT7:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT6]], <3 x float> poison, <3 x i32> zeroinitializer
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[TMP4:%.*]] = fmul <3 x float> [[SPLAT_SPLAT7]], [[BLOCK5]]
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[TMP5:%.*]] = fadd <3 x float> [[TMP2]], [[TMP4]]
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[BLOCK8:%.*]] = shufflevector <3 x float> [[COL_LOAD4]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[TMP6:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 2
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLATINSERT9:%.*]] = insertelement <3 x float> poison, float [[TMP6]], i64 0
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLAT10:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT9]], <3 x float> poison, <3 x i32> zeroinitializer
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[TMP7:%.*]] = fmul <3 x float> [[SPLAT_SPLAT10]], [[BLOCK8]]
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[TMP8:%.*]] = fadd <3 x float> [[TMP5]], [[TMP7]]
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[TMP9:%.*]] = shufflevector <3 x float> [[TMP8]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
+; RM_NO_SPLIT_REMAINDER-NEXT:    [[TMP10:%.*]] = shufflevector <3 x float> poison, <3 x float> [[TMP9]], <3 x i32> <i32 3, i32 4, i32 5>
+; RM_NO_SPLIT_REMAINDER-NEXT:    store <3 x float> [[TMP10]], ptr [[C]], align 4
+; RM_NO_SPLIT_REMAINDER-NEXT:    ret void
+;
+  %a_load = load <3 x float>, ptr %a, align 4
+  %b_load = load <9 x float>, ptr %b, align 4
+  %matmul = tail call <3 x float> @llvm.matrix.multiply.v3f32.v9f32.v3f32(<3 x float> %a_load, <9 x float> %b_load, i32 1, i32 3, i32 3)
+  store <3 x float> %matmul, ptr %c, align 4
+  ret void
+}
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-remainder.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-remainder.ll
new file mode 100644
index 0000000000000..b60c3858c31c6
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-remainder.ll
@@ -0,0 +1,95 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes='lower-matrix-intrinsics'                                  -S < %s | FileCheck --check-prefix=CM_SPLIT_REMAINDER %s
+; RUN: opt -passes='lower-matrix-intrinsics' -matrix-split-matmul-remainder=0 -S < %s | FileCheck --check-prefix=CM_NO_SPLIT_REMAINDER %s
+
+; REQUIRES: aarch64-registered-target
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:8:32:64-S128"
+target triple = "aarch64-apple-ios"
+
+define void @matmul(ptr %a, ptr %b, ptr %c) {
+; CM_SPLIT_REMAINDER-LABEL: define void @matmul(
+; CM_SPLIT_REMAINDER-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]]) {
+; CM_SPLIT_REMAINDER-NEXT:    [[COL_LOAD:%.*]] = load <3 x float>, ptr [[A]], align 4
+; CM_SPLIT_REMAINDER-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[A]], i64 3
+; CM_SPLIT_REMAINDER-NEXT:    [[COL_LOAD1:%.*]] = load <3 x float>, ptr [[VEC_GEP]], align 4
+; CM_SPLIT_REMAINDER-NEXT:    [[VEC_GEP2:%.*]] = getelementptr float, ptr [[A]], i64 6
+; CM_SPLIT_REMAINDER-NEXT:    [[COL_LOAD3:%.*]] = load <3 x float>, ptr [[VEC_GEP2]], align 4
+; CM_SPLIT_REMAINDER-NEXT:    [[COL_LOAD4:%.*]] = load <3 x float>, ptr [[B]], align 4
+; CM_SPLIT_REMAINDER-NEXT:    [[BLOCK:%.*]] = shufflevector <3 x float> [[COL_LOAD]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP1:%.*]] = extractelement <3 x float> [[COL_LOAD4]], i64 0
+; CM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLATINSERT:%.*]] = insertelement <2 x float> poison, float [[TMP1]], i64 0
+; CM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLAT:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT]], <2 x float> poison, <2 x i32> zeroinitializer
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP2:%.*]] = fmul <2 x float> [[BLOCK]], [[SPLAT_SPLAT]]
+; CM_SPLIT_REMAINDER-NEXT:    [[BLOCK5:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP3:%.*]] = extractelement <3 x float> [[COL_LOAD4]], i64 1
+; CM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLATINSERT6:%.*]] = insertelement <2 x float> poison, float [[TMP3]], i64 0
+; CM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLAT7:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT6]], <2 x float> poison, <2 x i32> zeroinitializer
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP4:%.*]] = fmul <2 x float> [[BLOCK5]], [[SPLAT_SPLAT7]]
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP5:%.*]] = fadd <2 x float> [[TMP2]], [[TMP4]]
+; CM_SPLIT_REMAINDER-NEXT:    [[BLOCK8:%.*]] = shufflevector <3 x float> [[COL_LOAD3]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP6:%.*]] = extractelement <3 x float> [[COL_LOAD4]], i64 2
+; CM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLATINSERT9:%.*]] = insertelement <2 x float> poison, float [[TMP6]], i64 0
+; CM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLAT10:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT9]], <2 x float> poison, <2 x i32> zeroinitializer
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP7:%.*]] = fmul <2 x float> [[BLOCK8]], [[SPLAT_SPLAT10]]
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP8:%.*]] = fadd <2 x float> [[TMP5]], [[TMP7]]
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP9:%.*]] = shufflevector <2 x float> [[TMP8]], <2 x float> poison, <3 x i32> <i32 0, i32 1, i32 poison>
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP10:%.*]] = shufflevector <3 x float> poison, <3 x float> [[TMP9]], <3 x i32> <i32 3, i32 4, i32 2>
+; CM_SPLIT_REMAINDER-NEXT:    [[BLOCK11:%.*]] = shufflevector <3 x float> [[COL_LOAD]], <3 x float> poison, <1 x i32> <i32 2>
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP11:%.*]] = extractelement <3 x float> [[COL_LOAD4]], i64 0
+; CM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLATINSERT12:%.*]] = insertelement <1 x float> poison, float [[TMP11]], i64 0
+; CM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLAT13:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT12]], <1 x float> poison, <1 x i32> zeroinitializer
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP12:%.*]] = fmul <1 x float> [[BLOCK11]], [[SPLAT_SPLAT13]]
+; CM_SPLIT_REMAINDER-NEXT:    [[BLOCK14:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <1 x i32> <i32 2>
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP13:%.*]] = extractelement <3 x float> [[COL_LOAD4]], i64 1
+; CM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLATINSERT15:%.*]] = insertelement <1 x float> poison, float [[TMP13]], i64 0
+; CM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLAT16:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT15]], <1 x float> poison, <1 x i32> zeroinitializer
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP14:%.*]] = fmul <1 x float> [[BLOCK14]], [[SPLAT_SPLAT16]]
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP15:%.*]] = fadd <1 x float> [[TMP12]], [[TMP14]]
+; CM_SPLIT_REMAINDER-NEXT:    [[BLOCK17:%.*]] = shufflevector <3 x float> [[COL_LOAD3]], <3 x float> poison, <1 x i32> <i32 2>
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP16:%.*]] = extractelement <3 x float> [[COL_LOAD4]], i64 2
+; CM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLATINSERT18:%.*]] = insertelement <1 x float> poison, float [[TMP16]], i64 0
+; CM_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLAT19:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT18]], <1 x float> poison, <1 x i32> zeroinitializer
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP17:%.*]] = fmul <1 x float> [[BLOCK17]], [[SPLAT_SPLAT19]]
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP18:%.*]] = fadd <1 x float> [[TMP15]], [[TMP17]]
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP19:%.*]] = shufflevector <1 x float> [[TMP18]], <1 x float> poison, <3 x i32> <i32 0, i32 poison, i32 poison>
+; CM_SPLIT_REMAINDER-NEXT:    [[TMP20:%.*]] = shufflevector <3 x float> [[TMP10]], <3 x float> [[TMP19]], <3 x i32> <i32 0, i32 1, i32 3>
+; CM_SPLIT_REMAINDER-NEXT:    store <3 x float> [[TMP20]], ptr [[C]], align 4
+; CM_SPLIT_REMAINDER-NEXT:    ret void
+;
+; CM_NO_SPLIT_REMAINDER-LABEL: define void @matmul(
+; CM_NO_SPLIT_REMAINDER-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]]) {
+; CM_NO_SPLIT_REMAINDER-NEXT:    [[COL_LOAD:%.*]] = load <3 x float>, ptr [[A]], align 4
+; CM_NO_SPLIT_REMAINDER-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[A]], i64 3
+; CM_NO_SPLIT_REMAINDER-NEXT:    [[COL_LOAD1:%.*]] = load <3 x float>, ptr [[VEC_GEP]], align 4
+; CM_NO_SPLIT_REMAINDER-NEXT:    [[VEC_GEP2:%.*]] = getelementptr float, ptr [[A]], i64 6
+; CM_NO_SPLIT_REMAINDER-NEXT:    [[COL_LOAD3:%.*]] = load <3 x float>, ptr [[VEC_GEP2]], align 4
+; CM_NO_SPLIT_REMAINDER-NEXT:    [[COL_LOAD4:%.*]] = load <3 x float>, ptr [[B]], align 4
+; CM_NO_SPLIT_REMAINDER-NEXT:    [[BLOCK:%.*]] = shufflevector <3 x float> [[COL_LOAD]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
+; CM_NO_SPLIT_REMAINDER-NEXT:    [[TMP1:%.*]] = extractelement <3 x float> [[COL_LOAD4]], i64 0
+; CM_NO_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLATINSERT:%.*]] = insertelement <3 x float> poison, float [[TMP1]], i64 0
+; CM_NO_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLAT:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT]], <3 x float> poison, <3 x i32> zeroinitializer
+; CM_NO_SPLIT_REMAINDER-NEXT:    [[TMP2:%.*]] = fmul <3 x float> [[BLOCK]], [[SPLAT_SPLAT]]
+; CM_NO_SPLIT_REMAINDER-NEXT:    [[BLOCK5:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
+; CM_NO_SPLIT_REMAINDER-NEXT:    [[TMP3:%.*]] = extractelement <3 x float> [[COL_LOAD4]], i64 1
+; CM_NO_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLATINSERT6:%.*]] = insertelement <3 x float> poison, float [[TMP3]], i64 0
+; CM_NO_SPLIT_REMAINDER-NEXT:    [[SPLAT_SPLAT7:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT6]], <3 x float> poison, <3 x i32> zeroinitializer
+; CM_NO_SPLIT_REMAINDER-NEXT:    [[TMP4:%.*]] = fmul <3 x float> [[BLOCK5]], [[SPLAT_SPLAT7]]
+; CM_NO_SPLIT_REMAINDER-NEXT:    [[TMP5:%.*]] = fadd <3 x float> [[TMP2]]...
[truncated]

@jroelofs
Copy link
Contributor

cc: @cofibrant (since I can't tag you as a reviewer yet)

@@ -97,6 +97,11 @@ static cl::opt<MatrixLayoutTy> MatrixLayout(
static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
cl::init(false));

static cl::opt<bool> SplitMatmulRemainder(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you make this a threshold instead of a boolean, I think it would be a little better for both the test and "what if"-type experimentation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! See the additional commit.

Copy link
Contributor

@jroelofs jroelofs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One more nit, but otherwise LGTM.

Co-authored-by: Jon Roelofs <[email protected]>
@anemet anemet enabled auto-merge (squash) October 17, 2025 18:05
@anemet anemet merged commit 783b050 into llvm:main Oct 17, 2025
7 of 9 checks passed
@anemet anemet deleted the dev/anemet/lmi-non-power-of-two-blocking branch October 17, 2025 18:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants