llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (11 additions, 0 deletions)
@@ -2272,6 +2272,17 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
      setPartialReduceMLAAction(MLAOps, VT,
                                MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
    }

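    // USDOT (+i8mm) handles mixed-sign i8 products, so custom-lower
    // PARTIAL_REDUCE_SUMLA for i8 inputs accumulating into i32 or i64.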
    if (Subtarget->hasMatMulInt8()) {
      if (VT.getVectorElementType() == MVT::i32)
        setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
                                  MVT::getVectorVT(MVT::i8, NumElts * 4),
                                  Custom);
      else if (VT.getVectorElementType() == MVT::i64)
        setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
                                  MVT::getVectorVT(MVT::i8, NumElts * 8),
                                  Custom);
    }
  }

  // Lower fixed length vector operations to scalable equivalents.
llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll (106 additions, 3 deletions)
@@ -1,7 +1,7 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mattr=+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
-; RUN: llc -mattr=+sve,+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
-; RUN: llc -mattr=+sme -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME
+; RUN: llc -mattr=+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
+; RUN: llc -mattr=+sve,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
+; RUN: llc -mattr=+sme,+i8mm -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME

target triple = "aarch64"

@@ -407,6 +407,33 @@ define <4 x i32> @four_way_i8_i32_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
ret <4 x i32> %partial.reduce
}

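; Mixed-sign variant: zext x sext i8 products accumulated into <4 x i32>
; should select a single usdot.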
define <4 x i32> @four_way_i8_i32_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
; COMMON-LABEL: four_way_i8_i32_vl128_usdot:
; COMMON: // %bb.0:
; COMMON-NEXT: ldr q0, [x0]
; COMMON-NEXT: ldr q1, [x1]
; COMMON-NEXT: ldr q2, [x2]
; COMMON-NEXT: usdot v0.4s, v1.16b, v2.16b
; COMMON-NEXT: ret
;
; SME-LABEL: four_way_i8_i32_vl128_usdot:
; SME: // %bb.0:
; SME-NEXT: ldr q0, [x0]
; SME-NEXT: ldr q1, [x1]
; SME-NEXT: ldr q2, [x2]
; SME-NEXT: usdot z0.s, z1.b, z2.b
; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
; SME-NEXT: ret
%acc = load <4 x i32>, ptr %accptr
%u = load <16 x i8>, ptr %uptr
%s = load <16 x i8>, ptr %sptr
%u.wide = zext <16 x i8> %u to <16 x i32>
%s.wide = sext <16 x i8> %s to <16 x i32>
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult)
ret <4 x i32> %partial.reduce
}

define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
;
; COMMON-LABEL: four_way_i8_i32_vl128_double_width:
@@ -438,6 +465,37 @@ define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
ret <8 x i32> %partial.reduce
}

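; As above, but the <8 x i32> accumulator at VL128 splits into two 128-bit
; usdot operations.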
define <8 x i32> @four_way_i8_i32_vl128_double_width_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
;
; COMMON-LABEL: four_way_i8_i32_vl128_double_width_usdot:
; COMMON: // %bb.0:
; COMMON-NEXT: ldp q0, q1, [x0]
; COMMON-NEXT: ldp q3, q2, [x1]
; COMMON-NEXT: ldp q5, q4, [x2]
; COMMON-NEXT: usdot v0.4s, v3.16b, v5.16b
; COMMON-NEXT: usdot v1.4s, v2.16b, v4.16b
; COMMON-NEXT: ret
;
; SME-LABEL: four_way_i8_i32_vl128_double_width_usdot:
; SME: // %bb.0:
; SME-NEXT: ldp q0, q1, [x0]
; SME-NEXT: ldp q3, q2, [x1]
; SME-NEXT: ldp q5, q4, [x2]
; SME-NEXT: usdot z0.s, z3.b, z5.b
; SME-NEXT: usdot z1.s, z2.b, z4.b
; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
; SME-NEXT: ret
%acc = load <8 x i32>, ptr %accptr
%u = load <32 x i8>, ptr %uptr
%s = load <32 x i8>, ptr %sptr
%u.wide = zext <32 x i8> %u to <32 x i32>
%s.wide = sext <32 x i8> %s to <32 x i32>
%mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
%partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
ret <8 x i32> %partial.reduce
}

define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
;
;
@@ -483,6 +541,51 @@ define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
ret <8 x i32> %partial.reduce
}

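; With a known 256-bit vector length, SVE and streaming SVE select a single
; wide usdot; plain NEON still splits into two 128-bit usdots.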
define <8 x i32> @four_way_i8_i32_vl256_usdot(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
;
;
; NEON-LABEL: four_way_i8_i32_vl256_usdot:
; NEON: // %bb.0:
; NEON-NEXT: ldp q0, q1, [x0]
; NEON-NEXT: ldp q3, q2, [x1]
; NEON-NEXT: ldp q5, q4, [x2]
; NEON-NEXT: usdot v0.4s, v3.16b, v5.16b
; NEON-NEXT: usdot v1.4s, v2.16b, v4.16b
; NEON-NEXT: ret
;
; SVE-LABEL: four_way_i8_i32_vl256_usdot:
; SVE: // %bb.0:
; SVE-NEXT: ldr z0, [x0]
; SVE-NEXT: ldr z1, [x1]
; SVE-NEXT: ldr z2, [x2]
; SVE-NEXT: usdot z0.s, z1.b, z2.b
; SVE-NEXT: mov z1.d, z0.d
; SVE-NEXT: ext z1.b, z1.b, z0.b, #16
; SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
; SVE-NEXT: // kill: def $q1 killed $q1 killed $z1
; SVE-NEXT: ret
;
; SME-LABEL: four_way_i8_i32_vl256_usdot:
; SME: // %bb.0:
; SME-NEXT: ldr z0, [x0]
; SME-NEXT: ldr z1, [x1]
; SME-NEXT: ldr z2, [x2]
; SME-NEXT: usdot z0.s, z1.b, z2.b
; SME-NEXT: mov z1.d, z0.d
; SME-NEXT: ext z1.b, z1.b, z0.b, #16
; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
; SME-NEXT: ret
%acc = load <8 x i32>, ptr %accptr
%u = load <32 x i8>, ptr %uptr
%s = load <32 x i8>, ptr %sptr
%u.wide = zext <32 x i8> %u to <32 x i32>
%s.wide = sext <32 x i8> %s to <32 x i32>
%mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
%partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
ret <8 x i32> %partial.reduce
}

;
; Four-way dot (i16 -> i64)
;