Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ static cl::opt<MatrixLayoutTy> MatrixLayout(
static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
cl::init(false));

static cl::opt<bool> SplitMatmulRemainder(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you make this a threshold instead of a boolean, I think it would be a little better for both the test and "what if"-type experimentation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! See the additional commit.

"matrix-split-matmul-remainder", cl::Hidden,
cl::desc("Split remainder vector in the inner loop of matmul"),
cl::init(true));

/// Helper function to either return Scope, if it is a subprogram or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
Expand Down Expand Up @@ -1719,6 +1724,26 @@ class LowerMatrixIntrinsics {
ToRemove.push_back(MatMul);
}

/// Given \p Remainder iterations of the the matmul inner loop,
/// potentially lower \p Blocksize that is used for the underlying
/// vector.
unsigned capBlockSize(unsigned BlockSize, unsigned Remainder, Type *EltType) {
if (BlockSize <= Remainder)
return BlockSize;

// If the remainder is also a legal type just use it.
if (TTI.isTypeLegal(FixedVectorType::get(EltType, Remainder)) ||
!SplitMatmulRemainder)
return Remainder;

// Gradually lower the vectorization factor to cover the
// remainder.
do {
BlockSize /= 2;
} while (BlockSize > Remainder);
return BlockSize;
}

/// Compute \p Result += \p A * \p B for input matrices with left-associating
/// addition.
///
Expand Down Expand Up @@ -1756,10 +1781,8 @@ class LowerMatrixIntrinsics {
bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));

for (unsigned I = 0; I < R; I += BlockSize) {
// Gradually lower the vectorization factor to cover the remainder.
while (I + BlockSize > R)
BlockSize /= 2;

// Lower block size to make sure we stay within bounds.
BlockSize = capBlockSize(BlockSize, R - I, Result.getElementType());
Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
: nullptr;
for (unsigned K = 0; K < M; ++K) {
Expand All @@ -1784,9 +1807,8 @@ class LowerMatrixIntrinsics {
unsigned BlockSize = VF;
bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
for (unsigned J = 0; J < C; J += BlockSize) {
// Gradually lower the vectorization factor to cover the remainder.
while (J + BlockSize > C)
BlockSize /= 2;
// Lower the vectorization factor to cover the remainder.
BlockSize = capBlockSize(BlockSize, C - J, Result.getElementType());

Value *Sum = nullptr;
for (unsigned K = 0; K < M; ++K) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -passes='lower-matrix-intrinsics' -matrix-default-layout=row-major -S < %s | FileCheck --check-prefix=RM_SPLIT_REMAINDER %s
; RUN: opt -passes='lower-matrix-intrinsics' -matrix-split-matmul-remainder=0 -matrix-default-layout=row-major -S < %s | FileCheck --check-prefix=RM_NO_SPLIT_REMAINDER %s

; REQUIRES: aarch64-registered-target

target datalayout = "e-m:o-i64:64-f80:128-n8:8:32:64-S128"
target triple = "aarch64-apple-ios"

define void @matmul(ptr %a, ptr %b, ptr %c) {
; RM_SPLIT_REMAINDER-LABEL: define void @matmul(
; RM_SPLIT_REMAINDER-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]]) {
; RM_SPLIT_REMAINDER-NEXT: [[COL_LOAD:%.*]] = load <3 x float>, ptr [[A]], align 4
; RM_SPLIT_REMAINDER-NEXT: [[COL_LOAD1:%.*]] = load <3 x float>, ptr [[B]], align 4
; RM_SPLIT_REMAINDER-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[B]], i64 3
; RM_SPLIT_REMAINDER-NEXT: [[COL_LOAD2:%.*]] = load <3 x float>, ptr [[VEC_GEP]], align 4
; RM_SPLIT_REMAINDER-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[B]], i64 6
; RM_SPLIT_REMAINDER-NEXT: [[COL_LOAD4:%.*]] = load <3 x float>, ptr [[VEC_GEP3]], align 4
; RM_SPLIT_REMAINDER-NEXT: [[BLOCK:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
; RM_SPLIT_REMAINDER-NEXT: [[TMP1:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 0
; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <2 x float> poison, float [[TMP1]], i64 0
; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT]], <2 x float> poison, <2 x i32> zeroinitializer
; RM_SPLIT_REMAINDER-NEXT: [[TMP2:%.*]] = fmul <2 x float> [[SPLAT_SPLAT]], [[BLOCK]]
; RM_SPLIT_REMAINDER-NEXT: [[BLOCK5:%.*]] = shufflevector <3 x float> [[COL_LOAD2]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
; RM_SPLIT_REMAINDER-NEXT: [[TMP3:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 1
; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT6:%.*]] = insertelement <2 x float> poison, float [[TMP3]], i64 0
; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT7:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT6]], <2 x float> poison, <2 x i32> zeroinitializer
; RM_SPLIT_REMAINDER-NEXT: [[TMP4:%.*]] = fmul <2 x float> [[SPLAT_SPLAT7]], [[BLOCK5]]
; RM_SPLIT_REMAINDER-NEXT: [[TMP5:%.*]] = fadd <2 x float> [[TMP2]], [[TMP4]]
; RM_SPLIT_REMAINDER-NEXT: [[BLOCK8:%.*]] = shufflevector <3 x float> [[COL_LOAD4]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
; RM_SPLIT_REMAINDER-NEXT: [[TMP6:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 2
; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT9:%.*]] = insertelement <2 x float> poison, float [[TMP6]], i64 0
; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT10:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT9]], <2 x float> poison, <2 x i32> zeroinitializer
; RM_SPLIT_REMAINDER-NEXT: [[TMP7:%.*]] = fmul <2 x float> [[SPLAT_SPLAT10]], [[BLOCK8]]
; RM_SPLIT_REMAINDER-NEXT: [[TMP8:%.*]] = fadd <2 x float> [[TMP5]], [[TMP7]]
; RM_SPLIT_REMAINDER-NEXT: [[TMP9:%.*]] = shufflevector <2 x float> [[TMP8]], <2 x float> poison, <3 x i32> <i32 0, i32 1, i32 poison>
; RM_SPLIT_REMAINDER-NEXT: [[TMP10:%.*]] = shufflevector <3 x float> poison, <3 x float> [[TMP9]], <3 x i32> <i32 3, i32 4, i32 2>
; RM_SPLIT_REMAINDER-NEXT: [[BLOCK11:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <1 x i32> <i32 2>
; RM_SPLIT_REMAINDER-NEXT: [[TMP11:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 0
; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT12:%.*]] = insertelement <1 x float> poison, float [[TMP11]], i64 0
; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT13:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT12]], <1 x float> poison, <1 x i32> zeroinitializer
; RM_SPLIT_REMAINDER-NEXT: [[TMP12:%.*]] = fmul <1 x float> [[SPLAT_SPLAT13]], [[BLOCK11]]
; RM_SPLIT_REMAINDER-NEXT: [[BLOCK14:%.*]] = shufflevector <3 x float> [[COL_LOAD2]], <3 x float> poison, <1 x i32> <i32 2>
; RM_SPLIT_REMAINDER-NEXT: [[TMP13:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 1
; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT15:%.*]] = insertelement <1 x float> poison, float [[TMP13]], i64 0
; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT16:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT15]], <1 x float> poison, <1 x i32> zeroinitializer
; RM_SPLIT_REMAINDER-NEXT: [[TMP14:%.*]] = fmul <1 x float> [[SPLAT_SPLAT16]], [[BLOCK14]]
; RM_SPLIT_REMAINDER-NEXT: [[TMP15:%.*]] = fadd <1 x float> [[TMP12]], [[TMP14]]
; RM_SPLIT_REMAINDER-NEXT: [[BLOCK17:%.*]] = shufflevector <3 x float> [[COL_LOAD4]], <3 x float> poison, <1 x i32> <i32 2>
; RM_SPLIT_REMAINDER-NEXT: [[TMP16:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 2
; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT18:%.*]] = insertelement <1 x float> poison, float [[TMP16]], i64 0
; RM_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT19:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT18]], <1 x float> poison, <1 x i32> zeroinitializer
; RM_SPLIT_REMAINDER-NEXT: [[TMP17:%.*]] = fmul <1 x float> [[SPLAT_SPLAT19]], [[BLOCK17]]
; RM_SPLIT_REMAINDER-NEXT: [[TMP18:%.*]] = fadd <1 x float> [[TMP15]], [[TMP17]]
; RM_SPLIT_REMAINDER-NEXT: [[TMP19:%.*]] = shufflevector <1 x float> [[TMP18]], <1 x float> poison, <3 x i32> <i32 0, i32 poison, i32 poison>
; RM_SPLIT_REMAINDER-NEXT: [[TMP20:%.*]] = shufflevector <3 x float> [[TMP10]], <3 x float> [[TMP19]], <3 x i32> <i32 0, i32 1, i32 3>
; RM_SPLIT_REMAINDER-NEXT: store <3 x float> [[TMP20]], ptr [[C]], align 4
; RM_SPLIT_REMAINDER-NEXT: ret void
;
; RM_NO_SPLIT_REMAINDER-LABEL: define void @matmul(
; RM_NO_SPLIT_REMAINDER-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]]) {
; RM_NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD:%.*]] = load <3 x float>, ptr [[A]], align 4
; RM_NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD1:%.*]] = load <3 x float>, ptr [[B]], align 4
; RM_NO_SPLIT_REMAINDER-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[B]], i64 3
; RM_NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD2:%.*]] = load <3 x float>, ptr [[VEC_GEP]], align 4
; RM_NO_SPLIT_REMAINDER-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[B]], i64 6
; RM_NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD4:%.*]] = load <3 x float>, ptr [[VEC_GEP3]], align 4
; RM_NO_SPLIT_REMAINDER-NEXT: [[BLOCK:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP1:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 0
; RM_NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <3 x float> poison, float [[TMP1]], i64 0
; RM_NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT]], <3 x float> poison, <3 x i32> zeroinitializer
; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP2:%.*]] = fmul <3 x float> [[SPLAT_SPLAT]], [[BLOCK]]
; RM_NO_SPLIT_REMAINDER-NEXT: [[BLOCK5:%.*]] = shufflevector <3 x float> [[COL_LOAD2]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP3:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 1
; RM_NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT6:%.*]] = insertelement <3 x float> poison, float [[TMP3]], i64 0
; RM_NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT7:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT6]], <3 x float> poison, <3 x i32> zeroinitializer
; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP4:%.*]] = fmul <3 x float> [[SPLAT_SPLAT7]], [[BLOCK5]]
; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP5:%.*]] = fadd <3 x float> [[TMP2]], [[TMP4]]
; RM_NO_SPLIT_REMAINDER-NEXT: [[BLOCK8:%.*]] = shufflevector <3 x float> [[COL_LOAD4]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP6:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 2
; RM_NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT9:%.*]] = insertelement <3 x float> poison, float [[TMP6]], i64 0
; RM_NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT10:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT9]], <3 x float> poison, <3 x i32> zeroinitializer
; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP7:%.*]] = fmul <3 x float> [[SPLAT_SPLAT10]], [[BLOCK8]]
; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP8:%.*]] = fadd <3 x float> [[TMP5]], [[TMP7]]
; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP9:%.*]] = shufflevector <3 x float> [[TMP8]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
; RM_NO_SPLIT_REMAINDER-NEXT: [[TMP10:%.*]] = shufflevector <3 x float> poison, <3 x float> [[TMP9]], <3 x i32> <i32 3, i32 4, i32 5>
; RM_NO_SPLIT_REMAINDER-NEXT: store <3 x float> [[TMP10]], ptr [[C]], align 4
; RM_NO_SPLIT_REMAINDER-NEXT: ret void
;
%a_load = load <3 x float>, ptr %a, align 4
%b_load = load <9 x float>, ptr %b, align 4
%matmul = tail call <3 x float> @llvm.matrix.multiply.v3f32.v9f32.v3f32(<3 x float> %a_load, <9 x float> %b_load, i32 1, i32 3, i32 3)
store <3 x float> %matmul, ptr %c, align 4
ret void
}
Loading