7 changes: 7 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2272,6 +2272,13 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
setPartialReduceMLAAction(MLAOps, VT,
MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
}

if (Subtarget->hasMatMulInt8()) {
  if (VT.getVectorElementType() == MVT::i32)
    setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
                              MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
  else if (VT.getVectorElementType() == MVT::i64)
    setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
                              MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
}
}

// Lower fixed length vector operations to scalable equivalents.
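For context: ISD::PARTIAL_REDUCE_SUMLA is the partial-reduction multiply-accumulate node whose multiplicands have mixed signedness (one sign-extended, one zero-extended), which is exactly the operation FEAT_I8MM's usdot instruction computes. Marking the action Custom for these fixed-length types routes such nodes through the existing AArch64 lowering rather than letting type legalization expand them. As a minimal sketch, with a hypothetical function name (the tests below are the authoritative coverage), the pattern this enables looks like:

; Sketch only: a zext * sext product feeding a partial reduction. With
; +i8mm, the i8 -> i32 case can now select usdot instead of being expanded.
define <4 x i32> @sumla_sketch(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
  %u.wide = zext <16 x i8> %u to <16 x i32>
  %s.wide = sext <16 x i8> %s to <16 x i32>
  %mul = mul <16 x i32> %s.wide, %u.wide
  %red = call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mul)
  ret <4 x i32> %red
}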
160 changes: 158 additions & 2 deletions llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
@@ -1,6 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mattr=+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
-; RUN: llc -mattr=+sve,+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
+; RUN: llc -mattr=+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
+; RUN: llc -mattr=+sve,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
; RUN: llc -mattr=+sme -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME

target triple = "aarch64"
@@ -407,6 +407,46 @@ define <4 x i32> @four_way_i8_i32_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
ret <4 x i32> %partial.reduce
}

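; Mixed-extend variant of four_way_i8_i32_vl128 above: one operand is
; zero-extended and the other sign-extended, so with +i8mm the reduction
; can select usdot.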
define <4 x i32> @four_way_i8_i32_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
; COMMON-LABEL: four_way_i8_i32_vl128_usdot:
; COMMON: // %bb.0:
; COMMON-NEXT: ldr q0, [x0]
; COMMON-NEXT: ldr q1, [x1]
; COMMON-NEXT: ldr q2, [x2]
; COMMON-NEXT: usdot v0.4s, v1.16b, v2.16b
; COMMON-NEXT: ret
;
; SME-LABEL: four_way_i8_i32_vl128_usdot:
; SME: // %bb.0:
; SME-NEXT: ptrue p0.s, vl4
; SME-NEXT: ldr q2, [x0]
; SME-NEXT: mov w8, #4 // =0x4
; SME-NEXT: ld1b { z0.s }, p0/z, [x1]
; SME-NEXT: ld1sb { z1.s }, p0/z, [x2]
; SME-NEXT: mad z0.s, p0/m, z1.s, z2.s
; SME-NEXT: ld1b { z1.s }, p0/z, [x1, x8]
; SME-NEXT: ld1sb { z2.s }, p0/z, [x2, x8]
; SME-NEXT: mov w8, #8 // =0x8
; SME-NEXT: mla z0.s, p0/m, z2.s, z1.s
; SME-NEXT: ld1b { z1.s }, p0/z, [x1, x8]
; SME-NEXT: ld1sb { z2.s }, p0/z, [x2, x8]
; SME-NEXT: mov w8, #12 // =0xc
; SME-NEXT: mla z0.s, p0/m, z2.s, z1.s
; SME-NEXT: ld1b { z1.s }, p0/z, [x1, x8]
; SME-NEXT: ld1sb { z2.s }, p0/z, [x2, x8]
; SME-NEXT: mla z0.s, p0/m, z2.s, z1.s
; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
; SME-NEXT: ret
%acc = load <4 x i32>, ptr %accptr
%u = load <16 x i8>, ptr %uptr
%s = load <16 x i8>, ptr %sptr
%u.wide = zext <16 x i8> %u to <16 x i32>
%s.wide = sext <16 x i8> %s to <16 x i32>
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult)
ret <4 x i32> %partial.reduce
}
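Note the contrast above: with i8mm available, the NEON and SVE runs (COMMON) collapse the whole mixed-sign reduction into a single usdot, while the streaming-SME run, which does not enable i8mm, is expanded into zero- and sign-extending loads (ld1b/ld1sb) feeding predicated multiply-accumulates (mad/mla).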

define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
;
; COMMON-LABEL: four_way_i8_i32_vl128_double_width:
@@ -438,6 +478,67 @@ define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
ret <8 x i32> %partial.reduce
}

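; Mixed-extend variant of the double-width case; each 128-bit half selects
; its own usdot.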
define <8 x i32> @four_way_i8_i32_vl128_double_width_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
;
; COMMON-LABEL: four_way_i8_i32_vl128_double_width_usdot:
; COMMON: // %bb.0:
; COMMON-NEXT: ldp q0, q1, [x0]
; COMMON-NEXT: ldp q3, q2, [x1]
; COMMON-NEXT: ldp q5, q4, [x2]
; COMMON-NEXT: usdot v0.4s, v3.16b, v5.16b
; COMMON-NEXT: usdot v1.4s, v2.16b, v4.16b
; COMMON-NEXT: ret
;
; SME-LABEL: four_way_i8_i32_vl128_double_width_usdot:
; SME: // %bb.0:
; SME-NEXT: ptrue p0.s, vl4
; SME-NEXT: mov w8, #16 // =0x10
; SME-NEXT: mov w9, #4 // =0x4
; SME-NEXT: ldp q5, q4, [x0]
; SME-NEXT: ld1b { z0.s }, p0/z, [x1, x8]
; SME-NEXT: ld1b { z1.s }, p0/z, [x1]
; SME-NEXT: ld1sb { z2.s }, p0/z, [x2, x8]
; SME-NEXT: ld1sb { z3.s }, p0/z, [x2]
; SME-NEXT: mov w8, #20 // =0x14
; SME-NEXT: ld1b { z6.s }, p0/z, [x1, x8]
; SME-NEXT: mad z0.s, p0/m, z2.s, z4.s
; SME-NEXT: ld1b { z2.s }, p0/z, [x1, x9]
; SME-NEXT: ld1sb { z4.s }, p0/z, [x2, x9]
; SME-NEXT: mad z1.s, p0/m, z3.s, z5.s
; SME-NEXT: ld1sb { z3.s }, p0/z, [x2, x8]
; SME-NEXT: mov w8, #24 // =0x18
; SME-NEXT: mov w9, #8 // =0x8
; SME-NEXT: ld1b { z5.s }, p0/z, [x1, x8]
; SME-NEXT: mla z0.s, p0/m, z3.s, z6.s
; SME-NEXT: ld1sb { z3.s }, p0/z, [x2, x8]
; SME-NEXT: mov w8, #28 // =0x1c
; SME-NEXT: mla z1.s, p0/m, z4.s, z2.s
; SME-NEXT: ld1b { z2.s }, p0/z, [x1, x9]
; SME-NEXT: ld1sb { z4.s }, p0/z, [x2, x9]
; SME-NEXT: mov w9, #12 // =0xc
; SME-NEXT: ld1b { z6.s }, p0/z, [x1, x8]
; SME-NEXT: mla z1.s, p0/m, z4.s, z2.s
; SME-NEXT: movprfx z2, z0
; SME-NEXT: mla z2.s, p0/m, z3.s, z5.s
; SME-NEXT: ld1b { z0.s }, p0/z, [x1, x9]
; SME-NEXT: ld1sb { z3.s }, p0/z, [x2, x8]
; SME-NEXT: ld1sb { z4.s }, p0/z, [x2, x9]
; SME-NEXT: mad z0.s, p0/m, z4.s, z1.s
; SME-NEXT: movprfx z1, z2
; SME-NEXT: mla z1.s, p0/m, z3.s, z6.s
; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
; SME-NEXT: ret
%acc = load <8 x i32>, ptr %accptr
%u = load <32 x i8>, ptr %uptr
%s = load <32 x i8>, ptr %sptr
%u.wide = zext <32 x i8> %u to <32 x i32>
%s.wide = sext <32 x i8> %s to <32 x i32>
%mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
%partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
ret <8 x i32> %partial.reduce
}

define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
;
;
@@ -483,6 +584,61 @@ define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
ret <8 x i32> %partial.reduce
}

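; Mixed-extend variant at vscale_range(2,2): the SVE run covers the whole
; <32 x i8> input with a single usdot on one Z register.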
define <8 x i32> @four_way_i8_i32_vl256_usdot(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
;
;
; NEON-LABEL: four_way_i8_i32_vl256_usdot:
; NEON: // %bb.0:
; NEON-NEXT: ldp q0, q1, [x0]
; NEON-NEXT: ldp q3, q2, [x1]
; NEON-NEXT: ldp q5, q4, [x2]
; NEON-NEXT: usdot v0.4s, v3.16b, v5.16b
; NEON-NEXT: usdot v1.4s, v2.16b, v4.16b
; NEON-NEXT: ret
;
; SVE-LABEL: four_way_i8_i32_vl256_usdot:
; SVE: // %bb.0:
; SVE-NEXT: ldr z0, [x0]
; SVE-NEXT: ldr z1, [x1]
; SVE-NEXT: ldr z2, [x2]
; SVE-NEXT: usdot z0.s, z1.b, z2.b
; SVE-NEXT: mov z1.d, z0.d
; SVE-NEXT: ext z1.b, z1.b, z0.b, #16
; SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
; SVE-NEXT: // kill: def $q1 killed $q1 killed $z1
; SVE-NEXT: ret
;
; SME-LABEL: four_way_i8_i32_vl256_usdot:
; SME: // %bb.0:
; SME-NEXT: ptrue p0.s
; SME-NEXT: ldr z0, [x0]
; SME-NEXT: ld1b { z1.s }, p0/z, [x1]
; SME-NEXT: ld1sb { z2.s }, p0/z, [x2]
; SME-NEXT: mla z0.s, p0/m, z2.s, z1.s
; SME-NEXT: ld1b { z1.s }, p0/z, [x1, #1, mul vl]
; SME-NEXT: ld1sb { z2.s }, p0/z, [x2, #1, mul vl]
; SME-NEXT: mla z0.s, p0/m, z2.s, z1.s
; SME-NEXT: ld1b { z1.s }, p0/z, [x1, #2, mul vl]
; SME-NEXT: ld1sb { z2.s }, p0/z, [x2, #2, mul vl]
; SME-NEXT: mla z0.s, p0/m, z2.s, z1.s
; SME-NEXT: ld1b { z1.s }, p0/z, [x1, #3, mul vl]
; SME-NEXT: ld1sb { z2.s }, p0/z, [x2, #3, mul vl]
; SME-NEXT: mla z0.s, p0/m, z2.s, z1.s
; SME-NEXT: mov z1.d, z0.d
; SME-NEXT: ext z1.b, z1.b, z0.b, #16
; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
; SME-NEXT: ret
%acc = load <8 x i32>, ptr %accptr
%u = load <32 x i8>, ptr %uptr
%s = load <32 x i8>, ptr %sptr
%u.wide = zext <32 x i8> %u to <32 x i32>
%s.wide = sext <32 x i8> %s to <32 x i32>
%mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
%partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
ret <8 x i32> %partial.reduce
}
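At a fixed 256-bit vector length the SVE run holds the entire <8 x i32> accumulator in one Z register, so a single usdot performs all 32 multiply-accumulates; the trailing ext then splits the result into the two Q registers required to return <8 x i32>.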

;
; Four-way dot (i16 -> i64)
;