Skip to content

Commit 5fa3ccb

Browse files
[AArch64] Use SVE fdot for partial.reduce.fadd for NEON types. (#167856)
We only seem to use the SVE fdot for fixed-length vector types when they are larger than 128bits, whereas we can also use them for 128bits vectors if SVE2p1/SME2 is available.
1 parent a5342d5 commit 5fa3ccb

File tree

2 files changed

+60
-11
lines changed

2 files changed

+60
-11
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1921,6 +1921,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
19211921
if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) {
19221922
setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, MVT::nxv4f32,
19231923
MVT::nxv8f16, Legal);
1924+
// We can use SVE2p1 fdot to emulate the fixed-length variant.
1925+
setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, MVT::v4f32,
1926+
MVT::v8f16, Custom);
19241927
}
19251928
}
19261929

llvm/test/CodeGen/AArch64/sve2p1-fixed-length-fdot.ll

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,43 @@
44

55
target triple = "aarch64-linux-gnu"
66

7+
define void @fdot_v4f32(ptr %accptr, ptr %aptr, ptr %bptr) {
8+
; SVE2-LABEL: fdot_v4f32:
9+
; SVE2: // %bb.0: // %entry
10+
; SVE2-NEXT: ldr q0, [x1]
11+
; SVE2-NEXT: ldr q1, [x2]
12+
; SVE2-NEXT: fcvtl v2.4s, v0.4h
13+
; SVE2-NEXT: fcvtl v3.4s, v1.4h
14+
; SVE2-NEXT: fcvtl2 v0.4s, v0.8h
15+
; SVE2-NEXT: fcvtl2 v1.4s, v1.8h
16+
; SVE2-NEXT: fmul v2.4s, v2.4s, v3.4s
17+
; SVE2-NEXT: ldr q3, [x0]
18+
; SVE2-NEXT: fmul v0.4s, v0.4s, v1.4s
19+
; SVE2-NEXT: fadd v1.4s, v3.4s, v2.4s
20+
; SVE2-NEXT: fadd v0.4s, v1.4s, v0.4s
21+
; SVE2-NEXT: str q0, [x0]
22+
; SVE2-NEXT: ret
23+
;
24+
; SVE2P1-LABEL: fdot_v4f32:
25+
; SVE2P1: // %bb.0: // %entry
26+
; SVE2P1-NEXT: ldr q0, [x0]
27+
; SVE2P1-NEXT: ldr q1, [x1]
28+
; SVE2P1-NEXT: ldr q2, [x2]
29+
; SVE2P1-NEXT: fdot z0.s, z1.h, z2.h
30+
; SVE2P1-NEXT: str q0, [x0]
31+
; SVE2P1-NEXT: ret
32+
entry:
33+
%acc = load <4 x float>, ptr %accptr
34+
%a = load <8 x half>, ptr %aptr
35+
%b = load <8 x half>, ptr %bptr
36+
%a.wide = fpext <8 x half> %a to <8 x float>
37+
%b.wide = fpext <8 x half> %b to <8 x float>
38+
%mult = fmul <8 x float> %a.wide, %b.wide
39+
%partial.reduce = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %mult)
40+
store <4 x float> %partial.reduce, ptr %accptr
41+
ret void
42+
}
43+
744
define void @fdot_wide_v8f32(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,0) {
845
; SVE2-LABEL: fdot_wide_v8f32:
946
; SVE2: // %bb.0: // %entry
@@ -177,17 +214,26 @@ entry:
177214
}
178215

179216
define <4 x float> @fixed_fdot_wide(<4 x float> %acc, <8 x half> %a, <8 x half> %b) {
180-
; CHECK-LABEL: fixed_fdot_wide:
181-
; CHECK: // %bb.0: // %entry
182-
; CHECK-NEXT: fcvtl v3.4s, v1.4h
183-
; CHECK-NEXT: fcvtl v4.4s, v2.4h
184-
; CHECK-NEXT: fcvtl2 v1.4s, v1.8h
185-
; CHECK-NEXT: fcvtl2 v2.4s, v2.8h
186-
; CHECK-NEXT: fmul v3.4s, v3.4s, v4.4s
187-
; CHECK-NEXT: fmul v1.4s, v1.4s, v2.4s
188-
; CHECK-NEXT: fadd v0.4s, v0.4s, v3.4s
189-
; CHECK-NEXT: fadd v0.4s, v0.4s, v1.4s
190-
; CHECK-NEXT: ret
217+
; SVE2-LABEL: fixed_fdot_wide:
218+
; SVE2: // %bb.0: // %entry
219+
; SVE2-NEXT: fcvtl v3.4s, v1.4h
220+
; SVE2-NEXT: fcvtl v4.4s, v2.4h
221+
; SVE2-NEXT: fcvtl2 v1.4s, v1.8h
222+
; SVE2-NEXT: fcvtl2 v2.4s, v2.8h
223+
; SVE2-NEXT: fmul v3.4s, v3.4s, v4.4s
224+
; SVE2-NEXT: fmul v1.4s, v1.4s, v2.4s
225+
; SVE2-NEXT: fadd v0.4s, v0.4s, v3.4s
226+
; SVE2-NEXT: fadd v0.4s, v0.4s, v1.4s
227+
; SVE2-NEXT: ret
228+
;
229+
; SVE2P1-LABEL: fixed_fdot_wide:
230+
; SVE2P1: // %bb.0: // %entry
231+
; SVE2P1-NEXT: // kill: def $q0 killed $q0 def $z0
232+
; SVE2P1-NEXT: // kill: def $q2 killed $q2 def $z2
233+
; SVE2P1-NEXT: // kill: def $q1 killed $q1 def $z1
234+
; SVE2P1-NEXT: fdot z0.s, z1.h, z2.h
235+
; SVE2P1-NEXT: // kill: def $q0 killed $q0 killed $z0
236+
; SVE2P1-NEXT: ret
191237
entry:
192238
%a.wide = fpext <8 x half> %a to <8 x float>
193239
%b.wide = fpext <8 x half> %b to <8 x float>

0 commit comments

Comments
 (0)