Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ static cl::opt<MatrixLayoutTy> MatrixLayout(
static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
cl::init(false));

static cl::opt<unsigned> SplitMatmulRemainderOverThreshold(
"matrix-split-matmul-remainder-over-threshold", cl::Hidden,
cl::desc("Illegal remainder vectors over this size in bits should be split "
"in the inner loop of matmul"),
cl::init(0));

/// Helper function to either return Scope, if it is a subprogram or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
Expand Down Expand Up @@ -1720,6 +1726,31 @@ class LowerMatrixIntrinsics {
ToRemove.push_back(MatMul);
}

/// Given \p Remainder iterations of the the matmul inner loop,
/// potentially lower \p Blocksize that is used for the underlying
/// vector.
unsigned capBlockSize(unsigned BlockSize, unsigned Remainder, Type *EltType) {
if (BlockSize <= Remainder)
return BlockSize;

// If the remainder is also a legal type just use it.
auto *VecTy = FixedVectorType::get(EltType, Remainder);
if (TTI.isTypeLegal(VecTy))
return Remainder;

// Similarly, if the vector is small enough that we don't want
// to split further.
if (VecTy->getPrimitiveSizeInBits() <= SplitMatmulRemainderOverThreshold)
return Remainder;

// Gradually lower the vectorization factor to cover the
// remainder.
do {
BlockSize /= 2;
} while (BlockSize > Remainder);
return BlockSize;
}

/// Compute \p Result += \p A * \p B for input matrices with left-associating
/// addition.
///
Expand Down Expand Up @@ -1757,10 +1788,8 @@ class LowerMatrixIntrinsics {
bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));

for (unsigned I = 0; I < R; I += BlockSize) {
// Gradually lower the vectorization factor to cover the remainder.
while (I + BlockSize > R)
BlockSize /= 2;

// Lower block size to make sure we stay within bounds.
BlockSize = capBlockSize(BlockSize, R - I, Result.getElementType());
Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
: nullptr;
for (unsigned K = 0; K < M; ++K) {
Expand All @@ -1785,9 +1814,8 @@ class LowerMatrixIntrinsics {
unsigned BlockSize = VF;
bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
for (unsigned J = 0; J < C; J += BlockSize) {
// Gradually lower the vectorization factor to cover the remainder.
while (J + BlockSize > C)
BlockSize /= 2;
// Lower the vectorization factor to cover the remainder.
BlockSize = capBlockSize(BlockSize, C - J, Result.getElementType());

Value *Sum = nullptr;
for (unsigned K = 0; K < M; ++K) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -passes='lower-matrix-intrinsics' -matrix-default-layout=row-major -S < %s | FileCheck --check-prefix=SPLIT_REMAINDER %s
; RUN: opt -passes='lower-matrix-intrinsics' -matrix-split-matmul-remainder-over-threshold=96 -matrix-default-layout=row-major -S < %s | FileCheck --check-prefix=NO_SPLIT_REMAINDER %s
; RUN: opt -passes='lower-matrix-intrinsics' -matrix-split-matmul-remainder-over-threshold=64 -matrix-default-layout=row-major -S < %s | FileCheck --check-prefix=SPLIT_REMAINDER %s

; REQUIRES: aarch64-registered-target

target datalayout = "e-m:o-i64:64-f80:128-n8:8:32:64-S128"
target triple = "aarch64-apple-ios"

define void @matmul(ptr %a, ptr %b, ptr %c) {
; SPLIT_REMAINDER-LABEL: define void @matmul(
; SPLIT_REMAINDER-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]]) {
; SPLIT_REMAINDER-NEXT: [[COL_LOAD:%.*]] = load <3 x float>, ptr [[A]], align 4
; SPLIT_REMAINDER-NEXT: [[COL_LOAD1:%.*]] = load <3 x float>, ptr [[B]], align 4
; SPLIT_REMAINDER-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[B]], i64 3
; SPLIT_REMAINDER-NEXT: [[COL_LOAD2:%.*]] = load <3 x float>, ptr [[VEC_GEP]], align 4
; SPLIT_REMAINDER-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[B]], i64 6
; SPLIT_REMAINDER-NEXT: [[COL_LOAD4:%.*]] = load <3 x float>, ptr [[VEC_GEP3]], align 4
; SPLIT_REMAINDER-NEXT: [[BLOCK:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
; SPLIT_REMAINDER-NEXT: [[TMP1:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 0
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <2 x float> poison, float [[TMP1]], i64 0
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT]], <2 x float> poison, <2 x i32> zeroinitializer
; SPLIT_REMAINDER-NEXT: [[TMP2:%.*]] = fmul <2 x float> [[SPLAT_SPLAT]], [[BLOCK]]
; SPLIT_REMAINDER-NEXT: [[BLOCK5:%.*]] = shufflevector <3 x float> [[COL_LOAD2]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
; SPLIT_REMAINDER-NEXT: [[TMP3:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 1
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT6:%.*]] = insertelement <2 x float> poison, float [[TMP3]], i64 0
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT7:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT6]], <2 x float> poison, <2 x i32> zeroinitializer
; SPLIT_REMAINDER-NEXT: [[TMP4:%.*]] = fmul <2 x float> [[SPLAT_SPLAT7]], [[BLOCK5]]
; SPLIT_REMAINDER-NEXT: [[TMP5:%.*]] = fadd <2 x float> [[TMP2]], [[TMP4]]
; SPLIT_REMAINDER-NEXT: [[BLOCK8:%.*]] = shufflevector <3 x float> [[COL_LOAD4]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
; SPLIT_REMAINDER-NEXT: [[TMP6:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 2
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT9:%.*]] = insertelement <2 x float> poison, float [[TMP6]], i64 0
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT10:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT9]], <2 x float> poison, <2 x i32> zeroinitializer
; SPLIT_REMAINDER-NEXT: [[TMP7:%.*]] = fmul <2 x float> [[SPLAT_SPLAT10]], [[BLOCK8]]
; SPLIT_REMAINDER-NEXT: [[TMP8:%.*]] = fadd <2 x float> [[TMP5]], [[TMP7]]
; SPLIT_REMAINDER-NEXT: [[TMP9:%.*]] = shufflevector <2 x float> [[TMP8]], <2 x float> poison, <3 x i32> <i32 0, i32 1, i32 poison>
; SPLIT_REMAINDER-NEXT: [[TMP10:%.*]] = shufflevector <3 x float> poison, <3 x float> [[TMP9]], <3 x i32> <i32 3, i32 4, i32 2>
; SPLIT_REMAINDER-NEXT: [[BLOCK11:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <1 x i32> <i32 2>
; SPLIT_REMAINDER-NEXT: [[TMP11:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 0
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT12:%.*]] = insertelement <1 x float> poison, float [[TMP11]], i64 0
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT13:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT12]], <1 x float> poison, <1 x i32> zeroinitializer
; SPLIT_REMAINDER-NEXT: [[TMP12:%.*]] = fmul <1 x float> [[SPLAT_SPLAT13]], [[BLOCK11]]
; SPLIT_REMAINDER-NEXT: [[BLOCK14:%.*]] = shufflevector <3 x float> [[COL_LOAD2]], <3 x float> poison, <1 x i32> <i32 2>
; SPLIT_REMAINDER-NEXT: [[TMP13:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 1
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT15:%.*]] = insertelement <1 x float> poison, float [[TMP13]], i64 0
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT16:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT15]], <1 x float> poison, <1 x i32> zeroinitializer
; SPLIT_REMAINDER-NEXT: [[TMP14:%.*]] = fmul <1 x float> [[SPLAT_SPLAT16]], [[BLOCK14]]
; SPLIT_REMAINDER-NEXT: [[TMP15:%.*]] = fadd <1 x float> [[TMP12]], [[TMP14]]
; SPLIT_REMAINDER-NEXT: [[BLOCK17:%.*]] = shufflevector <3 x float> [[COL_LOAD4]], <3 x float> poison, <1 x i32> <i32 2>
; SPLIT_REMAINDER-NEXT: [[TMP16:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 2
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT18:%.*]] = insertelement <1 x float> poison, float [[TMP16]], i64 0
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT19:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT18]], <1 x float> poison, <1 x i32> zeroinitializer
; SPLIT_REMAINDER-NEXT: [[TMP17:%.*]] = fmul <1 x float> [[SPLAT_SPLAT19]], [[BLOCK17]]
; SPLIT_REMAINDER-NEXT: [[TMP18:%.*]] = fadd <1 x float> [[TMP15]], [[TMP17]]
; SPLIT_REMAINDER-NEXT: [[TMP19:%.*]] = shufflevector <1 x float> [[TMP18]], <1 x float> poison, <3 x i32> <i32 0, i32 poison, i32 poison>
; SPLIT_REMAINDER-NEXT: [[TMP20:%.*]] = shufflevector <3 x float> [[TMP10]], <3 x float> [[TMP19]], <3 x i32> <i32 0, i32 1, i32 3>
; SPLIT_REMAINDER-NEXT: store <3 x float> [[TMP20]], ptr [[C]], align 4
; SPLIT_REMAINDER-NEXT: ret void
;
; NO_SPLIT_REMAINDER-LABEL: define void @matmul(
; NO_SPLIT_REMAINDER-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]]) {
; NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD:%.*]] = load <3 x float>, ptr [[A]], align 4
; NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD1:%.*]] = load <3 x float>, ptr [[B]], align 4
; NO_SPLIT_REMAINDER-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[B]], i64 3
; NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD2:%.*]] = load <3 x float>, ptr [[VEC_GEP]], align 4
; NO_SPLIT_REMAINDER-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[B]], i64 6
; NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD4:%.*]] = load <3 x float>, ptr [[VEC_GEP3]], align 4
; NO_SPLIT_REMAINDER-NEXT: [[BLOCK:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
; NO_SPLIT_REMAINDER-NEXT: [[TMP1:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 0
; NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <3 x float> poison, float [[TMP1]], i64 0
; NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT]], <3 x float> poison, <3 x i32> zeroinitializer
; NO_SPLIT_REMAINDER-NEXT: [[TMP2:%.*]] = fmul <3 x float> [[SPLAT_SPLAT]], [[BLOCK]]
; NO_SPLIT_REMAINDER-NEXT: [[BLOCK5:%.*]] = shufflevector <3 x float> [[COL_LOAD2]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
; NO_SPLIT_REMAINDER-NEXT: [[TMP3:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 1
; NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT6:%.*]] = insertelement <3 x float> poison, float [[TMP3]], i64 0
; NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT7:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT6]], <3 x float> poison, <3 x i32> zeroinitializer
; NO_SPLIT_REMAINDER-NEXT: [[TMP4:%.*]] = fmul <3 x float> [[SPLAT_SPLAT7]], [[BLOCK5]]
; NO_SPLIT_REMAINDER-NEXT: [[TMP5:%.*]] = fadd <3 x float> [[TMP2]], [[TMP4]]
; NO_SPLIT_REMAINDER-NEXT: [[BLOCK8:%.*]] = shufflevector <3 x float> [[COL_LOAD4]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
; NO_SPLIT_REMAINDER-NEXT: [[TMP6:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 2
; NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT9:%.*]] = insertelement <3 x float> poison, float [[TMP6]], i64 0
; NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT10:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT9]], <3 x float> poison, <3 x i32> zeroinitializer
; NO_SPLIT_REMAINDER-NEXT: [[TMP7:%.*]] = fmul <3 x float> [[SPLAT_SPLAT10]], [[BLOCK8]]
; NO_SPLIT_REMAINDER-NEXT: [[TMP8:%.*]] = fadd <3 x float> [[TMP5]], [[TMP7]]
; NO_SPLIT_REMAINDER-NEXT: [[TMP9:%.*]] = shufflevector <3 x float> [[TMP8]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
; NO_SPLIT_REMAINDER-NEXT: [[TMP10:%.*]] = shufflevector <3 x float> poison, <3 x float> [[TMP9]], <3 x i32> <i32 3, i32 4, i32 5>
; NO_SPLIT_REMAINDER-NEXT: store <3 x float> [[TMP10]], ptr [[C]], align 4
; NO_SPLIT_REMAINDER-NEXT: ret void
;
%a_load = load <3 x float>, ptr %a, align 4
%b_load = load <9 x float>, ptr %b, align 4
%matmul = tail call <3 x float> @llvm.matrix.multiply.v3f32.v9f32.v3f32(<3 x float> %a_load, <9 x float> %b_load, i32 1, i32 3, i32 3)
store <3 x float> %matmul, ptr %c, align 4
ret void
}
Loading
Loading