-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[LMI] Support non-power-of-2 types for the matmul remainder #163987
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[LMI] Support non-power-of-2 types for the matmul remainder #163987
Conversation
In the inner loop of matmul, instead of continuously halving the HW vector register width, I just use the remainder vector directly if it's legal. We don't have in-tree targets that have this, so I opted for adding a hidden flag to simulate this for testing purposes: -matrix-split-matmul-remainder=0. The tests are the vectorization-friendly 3x3x1 matrix-vector and 1x3x3 vector-matrix multiplies for CM, RM respectively.
@llvm/pr-subscribers-llvm-transforms Author: Adam Nemet (anemet) Changes: In the inner loop of matmul, instead of continuously halving the HW vector register width, I just use the remainder vector directly if it's legal. We don't have in-tree targets that have this, so I opted for adding a hidden flag to simulate this for testing purposes: -matrix-split-matmul-remainder=0. The tests are the vectorization-friendly 3x3x1 matrix-vector and 1x3x3 vector-matrix multiplies for CM, RM respectively. Patch is 21.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163987.diff 3 Files Affected:
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 7cae94ebb4ba1..242cbee8b3c9d 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -97,6 +97,11 @@ static cl::opt<MatrixLayoutTy> MatrixLayout(
static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
cl::init(false));
+static cl::opt<bool> SplitMatmulRemainder(
+ "matrix-split-matmul-remainder", cl::Hidden,
+ cl::desc("Split remainder vector in the inner loop of matmul"),
+ cl::init(true));
+
/// Helper function to either return Scope, if it is a subprogram or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
@@ -1719,6 +1724,26 @@ class LowerMatrixIntrinsics {
ToRemove.push_back(MatMul);
}
+ /// Given \p Remainder iterations of the matmul inner loop,
+ /// potentially lower \p BlockSize that is used for the underlying
+ /// vector.
+ unsigned capBlockSize(unsigned BlockSize, unsigned Remainder, Type *EltType) {
+ if (BlockSize <= Remainder)
+ return BlockSize;
+
+ // If the remainder is a legal type, or splitting is disabled, just use it.
+ if (TTI.isTypeLegal(FixedVectorType::get(EltType, Remainder)) ||
+ !SplitMatmulRemainder)
+ return Remainder;
+
+ // Gradually lower the vectorization factor to cover the
+ // remainder.
+ do {
+ BlockSize /= 2;
+ } while (BlockSize > Remainder);
+ return BlockSize;
+ }
+
/// Compute \p Result += \p A * \p B for input matrices with left-associating
/// addition.
///
@@ -1756,10 +1781,8 @@ class LowerMatrixIntrinsics {
bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
for (unsigned I = 0; I < R; I += BlockSize) {
- // Gradually lower the vectorization factor to cover the remainder.
- while (I + BlockSize > R)
- BlockSize /= 2;
-
+ // Lower block size to make sure we stay within bounds.
+ BlockSize = capBlockSize(BlockSize, R - I, Result.getElementType());
Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
: nullptr;
for (unsigned K = 0; K < M; ++K) {
@@ -1784,9 +1807,8 @@ class LowerMatrixIntrinsics {
unsigned BlockSize = VF;
bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
for (unsigned J = 0; J < C; J += BlockSize) {
- // Gradually lower the vectorization factor to cover the remainder.
- while (J + BlockSize > C)
- BlockSize /= 2;
+ // Lower the vectorization factor to cover the remainder.
+ BlockSize = capBlockSize(BlockSize, C - J, Result.getElementType());
Value *Sum = nullptr;
for (unsigned K = 0; K < M; ++K) {
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-remainder-rm.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-remainder-rm.ll
new file mode 100644
index 0000000000000..f71aa5305f679
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-remainder-rm.ll
@@ -0,0 +1,95 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes='lower-matrix-intrinsics' -matrix-default-layout=row-major -S < %s | FileCheck --check-prefix=RM_SPLIT_REMAINDER %s
+; RUN: opt -passes='lower-matrix-intrinsics' -matrix-split-matmul-remainder=0 -matrix-default-layout=row-major -S < %s | FileCheck --check-prefix=RM_NO_SPLIT_REMAINDER %s
+
+; REQUIRES: aarch64-registered-target
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:8:32:64-S128"
+target triple = "aarch64-apple-ios"
+
+define void @matmul(ptr %a, ptr %b, ptr %c) {
+; RM_SPLIT_REMAINDER-LABEL: define void @matmul(
+; RM_SPLIT_REMAINDER-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]]) {
+; RM_SPLIT_REMAINDER-NEXT: [[COL_LOAD:%.*]] = load <3 x float>, ptr [[A]], align 4
+; RM_SPLIT_REMAINDER-NEXT: [[COL_LOAD1:%.*]] = load <3 x float>, ptr [[B]], align 4
+; RM_SPLIT_REMAINDER-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[B]], i64 3
+; RM_SPLIT_REMAINDER-NEXT: [[COL_LOAD2:%.*]] = load <3 x float>, ptr [[VEC_GEP]], align 4
+; RM_SPLIT_REMAINDER-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[B]], i64 6
+; RM_SPLIT_REMAINDER-NEXT: [[COL_LOAD4:%.*]] = load <3 x float>, ptr [[VEC_GEP3]], align 4
+; RM_SPLIT_REMAINDER-NEXT: [[BLOCK:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
+; RM_SPLIT_REMAINDER-NEXT: [[TMP1:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 0
+; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <2 x float> poison, float [[TMP1]], i64 0
+; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT]], <2 x float> poison, <2 x i32> zeroinitializer
+; RM_SPLIT_REMAINDER-NEXT: [[TMP2:%.*]] = fmul <2 x float> [[SPLAT_SPLAT]], [[BLOCK]]
+; RM_SPLIT_REMAINDER-NEXT: [[BLOCK5:%.*]] = shufflevector <3 x float> [[COL_LOAD2]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
+; RM_SPLIT_REMAINDER-NEXT: [[TMP3:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 1
+; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT6:%.*]] = insertelement <2 x float> poison, float [[TMP3]], i64 0
+; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT7:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT6]], <2 x float> poison, <2 x i32> zeroinitializer
+; RM_SPLIT_REMAINDER-NEXT: [[TMP4:%.*]] = fmul <2 x float> [[SPLAT_SPLAT7]], [[BLOCK5]]
+; RM_SPLIT_REMAINDER-NEXT: [[TMP5:%.*]] = fadd <2 x float> [[TMP2]], [[TMP4]]
+; RM_SPLIT_REMAINDER-NEXT: [[BLOCK8:%.*]] = shufflevector <3 x float> [[COL_LOAD4]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
+; RM_SPLIT_REMAINDER-NEXT: [[TMP6:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 2
+; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT9:%.*]] = insertelement <2 x float> poison, float [[TMP6]], i64 0
+; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT10:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT9]], <2 x float> poison, <2 x i32> zeroinitializer
+; RM_SPLIT_REMAINDER-NEXT: [[TMP7:%.*]] = fmul <2 x float> [[SPLAT_SPLAT10]], [[BLOCK8]]
+; RM_SPLIT_REMAINDER-NEXT: [[TMP8:%.*]] = fadd <2 x float> [[TMP5]], [[TMP7]]
+; RM_SPLIT_REMAINDER-NEXT: [[TMP9:%.*]] = shufflevector <2 x float> [[TMP8]], <2 x float> poison, <3 x i32> <i32 0, i32 1, i32 poison>
+; RM_SPLIT_REMAINDER-NEXT: [[TMP10:%.*]] = shufflevector <3 x float> poison, <3 x float> [[TMP9]], <3 x i32> <i32 3, i32 4, i32 2>
+; RM_SPLIT_REMAINDER-NEXT: [[BLOCK11:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <1 x i32> <i32 2>
+; RM_SPLIT_REMAINDER-NEXT: [[TMP11:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 0
+; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT12:%.*]] = insertelement <1 x float> poison, float [[TMP11]], i64 0
+; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT13:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT12]], <1 x float> poison, <1 x i32> zeroinitializer
+; RM_SPLIT_REMAINDER-NEXT: [[TMP12:%.*]] = fmul <1 x float> [[SPLAT_SPLAT13]], [[BLOCK11]]
+; RM_SPLIT_REMAINDER-NEXT: [[BLOCK14:%.*]] = shufflevector <3 x float> [[COL_LOAD2]], <3 x float> poison, <1 x i32> <i32 2>
+; RM_SPLIT_REMAINDER-NEXT: [[TMP13:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 1
+; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT15:%.*]] = insertelement <1 x float> poison, float [[TMP13]], i64 0
+; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT16:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT15]], <1 x float> poison, <1 x i32> zeroinitializer
+; RM_SPLIT_REMAINDER-NEXT: [[TMP14:%.*]] = fmul <1 x float> [[SPLAT_SPLAT16]], [[BLOCK14]]
+; RM_SPLIT_REMAINDER-NEXT: [[TMP15:%.*]] = fadd <1 x float> [[TMP12]], [[TMP14]]
+; RM_SPLIT_REMAINDER-NEXT: [[BLOCK17:%.*]] = shufflevector <3 x float> [[COL_LOAD4]], <3 x float> poison, <1 x i32> <i32 2>
+; RM_SPLIT_REMAINDER-NEXT: [[TMP16:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 2
+; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT18:%.*]] = insertelement <1 x float> poison, float [[TMP16]], i64 0
+; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT19:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT18]], <1 x float> poison, <1 x i32> zeroinitializer
+; RM_SPLIT_REMAINDER-NEXT: [[TMP17:%.*]] = fmul <1 x float> [[SPLAT_SPLAT19]], [[BLOCK17]]
+; RM_SPLIT_REMAINDER-NEXT: [[TMP18:%.*]] = fadd <1 x float> [[TMP15]], [[TMP17]]
+; RM_SPLIT_REMAINDER-NEXT: [[TMP19:%.*]] = shufflevector <1 x float> [[TMP18]], <1 x float> poison, <3 x i32> <i32 0, i32 poison, i32 poison>
+; RM_SPLIT_REMAINDER-NEXT: [[TMP20:%.*]] = shufflevector <3 x float> [[TMP10]], <3 x float> [[TMP19]], <3 x i32> <i32 0, i32 1, i32 3>
+; RM_SPLIT_REMAINDER-NEXT: store <3 x float> [[TMP20]], ptr [[C]], align 4
+; RM_SPLIT_REMAINDER-NEXT: ret void
+;
+; RM_NO_SPLIT_REMAINDER-LABEL: define void @matmul(
+; RM_NO_SPLIT_REMAINDER-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]]) {
+; RM_NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD:%.*]] = load <3 x float>, ptr [[A]], align 4
+; RM_NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD1:%.*]] = load <3 x float>, ptr [[B]], align 4
+; RM_NO_SPLIT_REMAINDER-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[B]], i64 3
+; RM_NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD2:%.*]] = load <3 x float>, ptr [[VEC_GEP]], align 4
+; RM_NO_SPLIT_REMAINDER-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[B]], i64 6
+; RM_NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD4:%.*]] = load <3 x float>, ptr [[VEC_GEP3]], align 4
+; RM_NO_SPLIT_REMAINDER-NEXT: [[BLOCK:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
+; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP1:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 0
+; RM_NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <3 x float> poison, float [[TMP1]], i64 0
+; RM_NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT]], <3 x float> poison, <3 x i32> zeroinitializer
+; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP2:%.*]] = fmul <3 x float> [[SPLAT_SPLAT]], [[BLOCK]]
+; RM_NO_SPLIT_REMAINDER-NEXT: [[BLOCK5:%.*]] = shufflevector <3 x float> [[COL_LOAD2]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
+; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP3:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 1
+; RM_NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT6:%.*]] = insertelement <3 x float> poison, float [[TMP3]], i64 0
+; RM_NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT7:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT6]], <3 x float> poison, <3 x i32> zeroinitializer
+; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP4:%.*]] = fmul <3 x float> [[SPLAT_SPLAT7]], [[BLOCK5]]
+; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP5:%.*]] = fadd <3 x float> [[TMP2]], [[TMP4]]
+; RM_NO_SPLIT_REMAINDER-NEXT: [[BLOCK8:%.*]] = shufflevector <3 x float> [[COL_LOAD4]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
+; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP6:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 2
+; RM_NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT9:%.*]] = insertelement <3 x float> poison, float [[TMP6]], i64 0
+; RM_NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT10:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT9]], <3 x float> poison, <3 x i32> zeroinitializer
+; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP7:%.*]] = fmul <3 x float> [[SPLAT_SPLAT10]], [[BLOCK8]]
+; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP8:%.*]] = fadd <3 x float> [[TMP5]], [[TMP7]]
+; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP9:%.*]] = shufflevector <3 x float> [[TMP8]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
+; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP10:%.*]] = shufflevector <3 x float> poison, <3 x float> [[TMP9]], <3 x i32> <i32 3, i32 4, i32 5>
+; RM_NO_SPLIT_REMAINDER-NEXT: store <3 x float> [[TMP10]], ptr [[C]], align 4
+; RM_NO_SPLIT_REMAINDER-NEXT: ret void
+;
+ %a_load = load <3 x float>, ptr %a, align 4
+ %b_load = load <9 x float>, ptr %b, align 4
+ %matmul = tail call <3 x float> @llvm.matrix.multiply.v3f32.v9f32.v3f32(<3 x float> %a_load, <9 x float> %b_load, i32 1, i32 3, i32 3)
+ store <3 x float> %matmul, ptr %c, align 4
+ ret void
+}
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-remainder.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-remainder.ll
new file mode 100644
index 0000000000000..b60c3858c31c6
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-remainder.ll
@@ -0,0 +1,95 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck --check-prefix=CM_SPLIT_REMAINDER %s
+; RUN: opt -passes='lower-matrix-intrinsics' -matrix-split-matmul-remainder=0 -S < %s | FileCheck --check-prefix=CM_NO_SPLIT_REMAINDER %s
+
+; REQUIRES: aarch64-registered-target
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:8:32:64-S128"
+target triple = "aarch64-apple-ios"
+
+define void @matmul(ptr %a, ptr %b, ptr %c) {
+; CM_SPLIT_REMAINDER-LABEL: define void @matmul(
+; CM_SPLIT_REMAINDER-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]]) {
+; CM_SPLIT_REMAINDER-NEXT: [[COL_LOAD:%.*]] = load <3 x float>, ptr [[A]], align 4
+; CM_SPLIT_REMAINDER-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[A]], i64 3
+; CM_SPLIT_REMAINDER-NEXT: [[COL_LOAD1:%.*]] = load <3 x float>, ptr [[VEC_GEP]], align 4
+; CM_SPLIT_REMAINDER-NEXT: [[VEC_GEP2:%.*]] = getelementptr float, ptr [[A]], i64 6
+; CM_SPLIT_REMAINDER-NEXT: [[COL_LOAD3:%.*]] = load <3 x float>, ptr [[VEC_GEP2]], align 4
+; CM_SPLIT_REMAINDER-NEXT: [[COL_LOAD4:%.*]] = load <3 x float>, ptr [[B]], align 4
+; CM_SPLIT_REMAINDER-NEXT: [[BLOCK:%.*]] = shufflevector <3 x float> [[COL_LOAD]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
+; CM_SPLIT_REMAINDER-NEXT: [[TMP1:%.*]] = extractelement <3 x float> [[COL_LOAD4]], i64 0
+; CM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <2 x float> poison, float [[TMP1]], i64 0
+; CM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT]], <2 x float> poison, <2 x i32> zeroinitializer
+; CM_SPLIT_REMAINDER-NEXT: [[TMP2:%.*]] = fmul <2 x float> [[BLOCK]], [[SPLAT_SPLAT]]
+; CM_SPLIT_REMAINDER-NEXT: [[BLOCK5:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
+; CM_SPLIT_REMAINDER-NEXT: [[TMP3:%.*]] = extractelement <3 x float> [[COL_LOAD4]], i64 1
+; CM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT6:%.*]] = insertelement <2 x float> poison, float [[TMP3]], i64 0
+; CM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT7:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT6]], <2 x float> poison, <2 x i32> zeroinitializer
+; CM_SPLIT_REMAINDER-NEXT: [[TMP4:%.*]] = fmul <2 x float> [[BLOCK5]], [[SPLAT_SPLAT7]]
+; CM_SPLIT_REMAINDER-NEXT: [[TMP5:%.*]] = fadd <2 x float> [[TMP2]], [[TMP4]]
+; CM_SPLIT_REMAINDER-NEXT: [[BLOCK8:%.*]] = shufflevector <3 x float> [[COL_LOAD3]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
+; CM_SPLIT_REMAINDER-NEXT: [[TMP6:%.*]] = extractelement <3 x float> [[COL_LOAD4]], i64 2
+; CM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT9:%.*]] = insertelement <2 x float> poison, float [[TMP6]], i64 0
+; CM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT10:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT9]], <2 x float> poison, <2 x i32> zeroinitializer
+; CM_SPLIT_REMAINDER-NEXT: [[TMP7:%.*]] = fmul <2 x float> [[BLOCK8]], [[SPLAT_SPLAT10]]
+; CM_SPLIT_REMAINDER-NEXT: [[TMP8:%.*]] = fadd <2 x float> [[TMP5]], [[TMP7]]
+; CM_SPLIT_REMAINDER-NEXT: [[TMP9:%.*]] = shufflevector <2 x float> [[TMP8]], <2 x float> poison, <3 x i32> <i32 0, i32 1, i32 poison>
+; CM_SPLIT_REMAINDER-NEXT: [[TMP10:%.*]] = shufflevector <3 x float> poison, <3 x float> [[TMP9]], <3 x i32> <i32 3, i32 4, i32 2>
+; CM_SPLIT_REMAINDER-NEXT: [[BLOCK11:%.*]] = shufflevector <3 x float> [[COL_LOAD]], <3 x float> poison, <1 x i32> <i32 2>
+; CM_SPLIT_REMAINDER-NEXT: [[TMP11:%.*]] = extractelement <3 x float> [[COL_LOAD4]], i64 0
+; CM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT12:%.*]] = insertelement <1 x float> poison, float [[TMP11]], i64 0
+; CM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT13:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT12]], <1 x float> poison, <1 x i32> zeroinitializer
+; CM_SPLIT_REMAINDER-NEXT: [[TMP12:%.*]] = fmul <1 x float> [[BLOCK11]], [[SPLAT_SPLAT13]]
+; CM_SPLIT_REMAINDER-NEXT: [[BLOCK14:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <1 x i32> <i32 2>
+; CM_SPLIT_REMAINDER-NEXT: [[TMP13:%.*]] = extractelement <3 x float> [[COL_LOAD4]], i64 1
+; CM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT15:%.*]] = insertelement <1 x float> poison, float [[TMP13]], i64 0
+; CM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT16:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT15]], <1 x float> poison, <1 x i32> zeroinitializer
+; CM_SPLIT_REMAINDER-NEXT: [[TMP14:%.*]] = fmul <1 x float> [[BLOCK14]], [[SPLAT_SPLAT16]]
+; CM_SPLIT_REMAINDER-NEXT: [[TMP15:%.*]] = fadd <1 x float> [[TMP12]], [[TMP14]]
+; CM_SPLIT_REMAINDER-NEXT: [[BLOCK17:%.*]] = shufflevector <3 x float> [[COL_LOAD3]], <3 x float> poison, <1 x i32> <i32 2>
+; CM_SPLIT_REMAINDER-NEXT: [[TMP16:%.*]] = extractelement <3 x float> [[COL_LOAD4]], i64 2
+; CM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT18:%.*]] = insertelement <1 x float> poison, float [[TMP16]], i64 0
+; CM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT19:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT18]], <1 x float> poison, <1 x i32> zeroinitializer
+; CM_SPLIT_REMAINDER-NEXT: [[TMP17:%.*]] = fmul <1 x float> [[BLOCK17]], [[SPLAT_SPLAT19]]
+; CM_SPLIT_REMAINDER-NEXT: [[TMP18:%.*]] = fadd <1 x float> [[TMP15]], [[TMP17]]
+; CM_SPLIT_REMAINDER-NEXT: [[TMP19:%.*]] = shufflevector <1 x float> [[TMP18]], <1 x float> poison, <3 x i32> <i32 0, i32 poison, i32 poison>
+; CM_SPLIT_REMAINDER-NEXT: [[TMP20:%.*]] = shufflevector <3 x float> [[TMP10]], <3 x float> [[TMP19]], <3 x i32> <i32 0, i32 1, i32 3>
+; CM_SPLIT_REMAINDER-NEXT: store <3 x float> [[TMP20]], ptr [[C]], align 4
+; CM_SPLIT_REMAINDER-NEXT: ret void
+;
+; CM_NO_SPLIT_REMAINDER-LABEL: define void @matmul(
+; CM_NO_SPLIT_REMAINDER-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]]) {
+; CM_NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD:%.*]] = load <3 x float>, ptr [[A]], align 4
+; CM_NO_SPLIT_REMAINDER-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[A]], i64 3
+; CM_NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD1:%.*]] = load <3 x float>, ptr [[VEC_GEP]], align 4
+; CM_NO_SPLIT_REMAINDER-NEXT: [[VEC_GEP2:%.*]] = getelementptr float, ptr [[A]], i64 6
+; CM_NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD3:%.*]] = load <3 x float>, ptr [[VEC_GEP2]], align 4
+; CM_NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD4:%.*]] = load <3 x float>, ptr [[B]], align 4
+; CM_NO_SPLIT_REMAINDER-NEXT: [[BLOCK:%.*]] = shufflevector <3 x float> [[COL_LOAD]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
+; CM_NO_SPLIT_REMAINDER-NEXT: [[TMP1:%.*]] = extractelement <3 x float> [[COL_LOAD4]], i64 0
+; CM_NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <3 x float> poison, float [[TMP1]], i64 0
+; CM_NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT]], <3 x float> poison, <3 x i32> zeroinitializer
+; CM_NO_SPLIT_REMAINDER-NEXT: [[TMP2:%.*]] = fmul <3 x float> [[BLOCK]], [[SPLAT_SPLAT]]
+; CM_NO_SPLIT_REMAINDER-NEXT: [[BLOCK5:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
+; CM_NO_SPLIT_REMAINDER-NEXT: [[TMP3:%.*]] = extractelement <3 x float> [[COL_LOAD4]], i64 1
+; CM_NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT6:%.*]] = insertelement <3 x float> poison, float [[TMP3]], i64 0
+; CM_NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT7:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT6]], <3 x float> poison, <3 x i32> zeroinitializer
+; CM_NO_SPLIT_REMAINDER-NEXT: [[TMP4:%.*]] = fmul <3 x float> [[BLOCK5]], [[SPLAT_SPLAT7]]
+; CM_NO_SPLIT_REMAINDER-NEXT: [[TMP5:%.*]] = fadd <3 x float> [[TMP2]]...
[truncated]
|
cc: @cofibrant (since I can't tag you as a reviewer yet) |
@@ -97,6 +97,11 @@ static cl::opt<MatrixLayoutTy> MatrixLayout( | |||
static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt", | |||
cl::init(false)); | |||
|
|||
static cl::opt<bool> SplitMatmulRemainder( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you make this a threshold instead of a boolean, I think it would be a little better for both the test and "what if"-type experimentation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea! See the additional commit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One more nit, but otherwise LGTM.
Co-authored-by: Jon Roelofs <[email protected]>
In the inner loop of matmul, instead of continuously halving the HW vector register width, I just use the remainder vector directly if it's legal.
We don't have in-tree targets that have this, so I opted for adding a hidden flag to simulate this for testing purposes: -matrix-split-matmul-remainder=0.
The tests are the vectorization-friendly 3x3x1 matrix-vector and 1x3x3 vector-matrix multiplies for CM, RM respectively.