
Commit 3eea01e

[AArch64] Add ISel support for partial reductions to use SVE2.1 udot/sdot (#158310)
This allows dot products with scalable 8xi16 vectors (and fixed-length vectors that are converted into scalable vectors), accumulating into a 4xi32 vector, to lower to a single instruction (`udot`/`sdot`) rather than a sequence of `umlalb`s and `umlalt`s.
1 parent f9f62ef commit 3eea01e
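For context, the motivating pattern is a widening multiply feeding the `@llvm.experimental.vector.partial.reduce.add` intrinsic. A minimal IR sketch (mirroring the tests added below; the function name is illustrative):

```llvm
; Widening multiply accumulated into a narrower vector: with +sve2p1 (or
; +sme2 in streaming mode) this now selects to a single udot z0.s, z1.h, z2.h.
define <vscale x 4 x i32> @dotprod(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i32>
  %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i32>
  %mul = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
  ; Pairs of adjacent i32 products are folded into the nxv4i32 accumulator.
  %red = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mul)
  ret <vscale x 4 x i32> %red
}
```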

2 files changed: +112 −0
llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 7 additions & 0 deletions
@@ -4238,6 +4238,13 @@ defm UDOT_ZZZ_HtoS : sve2p1_two_way_dot_vv<"udot", 0b1, int_aarch64_sve_udot_x2
 defm SDOT_ZZZI_HtoS : sve2p1_two_way_dot_vvi<"sdot", 0b0, int_aarch64_sve_sdot_lane_x2>;
 defm UDOT_ZZZI_HtoS : sve2p1_two_way_dot_vvi<"udot", 0b1, int_aarch64_sve_udot_lane_x2>;
 
+let Predicates = [HasSVE2p1_or_SME2] in {
+def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
+          (UDOT_ZZZ_HtoS $Acc, $MulLHS, $MulRHS)>;
+def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
+          (SDOT_ZZZ_HtoS $Acc, $MulLHS, $MulRHS)>;
+} // End HasSVE2p1_or_SME2
+
 defm SQCVTN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"sqcvtn", 0b00, int_aarch64_sve_sqcvtn_x2>;
 defm UQCVTN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"uqcvtn", 0b01, int_aarch64_sve_uqcvtn_x2>;
 defm SQCVTUN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"sqcvtun", 0b10, int_aarch64_sve_sqcvtun_x2>;
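These patterns map the generic `partial_reduce_umla`/`partial_reduce_smla` nodes directly onto the SVE2.1 two-way dot-product instructions. Per 32-bit lane, the instruction folds two adjacent widened 16-bit products into the accumulator. A scalar model of one lane of the unsigned form (an illustrative sketch, not code from the patch):

```llvm
; One 32-bit lane of udot z0.s, z1.h, z2.h:
;   acc + a[2i]*b[2i] + a[2i+1]*b[2i+1], with the i16 inputs zero-extended.
define i32 @udot_one_lane(i32 %acc, i16 %a0, i16 %a1, i16 %b0, i16 %b1) {
  %a0.w = zext i16 %a0 to i32
  %b0.w = zext i16 %b0 to i32
  %a1.w = zext i16 %a1 to i32
  %b1.w = zext i16 %b1 to i32
  %p0 = mul i32 %a0.w, %b0.w
  %p1 = mul i32 %a1.w, %b1.w
  %sum = add i32 %p0, %p1
  %res = add i32 %acc, %sum
  ret i32 %res
}
```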
Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2p1 < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -force-streaming < %s | FileCheck %s
+
+define <vscale x 4 x i32> @udot_vl128(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: udot_vl128:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    udot z0.s, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i32>
+  %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i32>
+  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sdot_vl128(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: sdot_vl128:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i32>
+  %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i32>
+  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define void @udot_vl256(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,2) {
+; CHECK-LABEL: udot_vl256:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr z0, [x0]
+; CHECK-NEXT:    ldr z1, [x1]
+; CHECK-NEXT:    ldr z2, [x2]
+; CHECK-NEXT:    udot z0.s, z1.h, z2.h
+; CHECK-NEXT:    str z0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %acc = load <8 x i32>, ptr %accptr
+  %a = load <16 x i16>, ptr %aptr
+  %b = load <16 x i16>, ptr %bptr
+  %a.wide = zext <16 x i16> %a to <16 x i32>
+  %b.wide = zext <16 x i16> %b to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <16 x i32> %mult)
+  store <8 x i32> %partial.reduce, ptr %accptr
+  ret void
+}
+
+define void @sdot_vl256(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,2) {
+; CHECK-LABEL: sdot_vl256:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr z0, [x0]
+; CHECK-NEXT:    ldr z1, [x1]
+; CHECK-NEXT:    ldr z2, [x2]
+; CHECK-NEXT:    sdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    str z0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %acc = load <8 x i32>, ptr %accptr
+  %a = load <16 x i16>, ptr %aptr
+  %b = load <16 x i16>, ptr %bptr
+  %a.wide = sext <16 x i16> %a to <16 x i32>
+  %b.wide = sext <16 x i16> %b to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <16 x i32> %mult)
+  store <8 x i32> %partial.reduce, ptr %accptr
+  ret void
+}
+
+define <4 x i32> @fixed_udot_s_h(<4 x i32> %acc, <8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: fixed_udot_s_h:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    udot z0.s, z1.h, z2.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <8 x i16> %a to <8 x i32>
+  %b.wide = zext <8 x i16> %b to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <8 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <4 x i32> @fixed_sdot_s_h(<4 x i32> %acc, <8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: fixed_sdot_s_h:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    sdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <8 x i16> %a to <8 x i32>
+  %b.wide = sext <8 x i16> %b to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <8 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
