Skip to content

Commit ab25e26

Browse files
Lian Wang
authored and committed
[SelectionDAG] Enable WidenVecOp_VECREDUCE_SEQ for scalable vector
Reviewed By: sdesmalen
Differential Revision: https://reviews.llvm.org/D127710
1 parent 8323209 commit ab25e26

File tree

3 files changed

+190
-4
lines changed

3 files changed

+190
-4
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6144,8 +6144,20 @@ SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE_SEQ(SDNode *N) {
61446144
SDValue NeutralElem = DAG.getNeutralElement(BaseOpc, dl, ElemVT, Flags);
61456145

61466146
// Pad the vector with the neutral element.
6147-
unsigned OrigElts = OrigVT.getVectorNumElements();
6148-
unsigned WideElts = WideVT.getVectorNumElements();
6147+
unsigned OrigElts = OrigVT.getVectorMinNumElements();
6148+
unsigned WideElts = WideVT.getVectorMinNumElements();
6149+
6150+
if (WideVT.isScalableVector()) {
6151+
unsigned GCD = greatestCommonDivisor(OrigElts, WideElts);
6152+
EVT SplatVT = EVT::getVectorVT(*DAG.getContext(), ElemVT,
6153+
ElementCount::getScalable(GCD));
6154+
SDValue SplatNeutral = DAG.getSplatVector(SplatVT, dl, NeutralElem);
6155+
for (unsigned Idx = OrigElts; Idx < WideElts; Idx = Idx + GCD)
6156+
Op = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideVT, Op, SplatNeutral,
6157+
DAG.getVectorIdxConstant(Idx, dl));
6158+
return DAG.getNode(Opc, dl, N->getValueType(0), AccOp, Op, Flags);
6159+
}
6160+
61496161
for (unsigned Idx = OrigElts; Idx < WideElts; Idx++)
61506162
Op = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, WideVT, Op, NeutralElem,
61516163
DAG.getVectorIdxConstant(Idx, dl));

llvm/test/CodeGen/AArch64/sve-fp-reduce.ll

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,79 @@ define half @fadda_nxv8f16(half %init, <vscale x 8 x half> %a) {
2929
ret half %res
3030
}
3131

32+
define half @fadda_nxv6f16(<vscale x 6 x half> %v, half %s) {
33+
; CHECK-LABEL: fadda_nxv6f16:
34+
; CHECK: str x29, [sp, #-16]!
35+
; CHECK-NEXT: .cfi_def_cfa_offset 16
36+
; CHECK-NEXT: .cfi_offset w29, -16
37+
; CHECK-NEXT: addvl sp, sp, #-1
38+
; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22
39+
; CHECK-NEXT: adrp x8, .LCPI3_0
40+
; CHECK-NEXT: add x8, x8, :lo12:.LCPI3_0
41+
; CHECK-NEXT: ptrue p0.h
42+
; CHECK-NEXT: ptrue p1.d
43+
; CHECK-NEXT: st1h { z0.h }, p0, [sp]
44+
; CHECK-NEXT: ld1rh { z0.d }, p1/z, [x8]
45+
; CHECK-NEXT: st1h { z0.d }, p1, [sp, #3, mul vl]
46+
; CHECK-NEXT: fmov s0, s1
47+
; CHECK-NEXT: ld1h { z2.h }, p0/z, [sp]
48+
; CHECK-NEXT: fadda h0, p0, h0, z2.h
49+
; CHECK-NEXT: addvl sp, sp, #1
50+
; CHECK-NEXT: ldr x29, [sp], #16
51+
; CHECK-NEXT: ret
52+
%res = call half @llvm.vector.reduce.fadd.nxv6f16(half %s, <vscale x 6 x half> %v)
53+
ret half %res
54+
}
55+
56+
define half @fadda_nxv10f16(<vscale x 10 x half> %v, half %s) {
57+
; CHECK-LABEL: fadda_nxv10f16:
58+
; CHECK: str x29, [sp, #-16]!
59+
; CHECK-NEXT: .cfi_def_cfa_offset 16
60+
; CHECK-NEXT: .cfi_offset w29, -16
61+
; CHECK-NEXT: addvl sp, sp, #-3
62+
; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x18, 0x92, 0x2e, 0x00, 0x1e, 0x22
63+
; CHECK-NEXT: adrp x8, .LCPI4_0
64+
; CHECK-NEXT: add x8, x8, :lo12:.LCPI4_0
65+
; CHECK-NEXT: ptrue p0.h
66+
; CHECK-NEXT: ptrue p1.d
67+
; CHECK-NEXT: st1h { z1.h }, p0, [sp]
68+
; CHECK-NEXT: ld1rh { z1.d }, p1/z, [x8]
69+
; CHECK-NEXT: addvl x8, sp, #1
70+
; CHECK-NEXT: fadda h2, p0, h2, z0.h
71+
; CHECK-NEXT: st1h { z1.d }, p1, [sp, #1, mul vl]
72+
; CHECK-NEXT: ld1h { z3.h }, p0/z, [sp]
73+
; CHECK-NEXT: st1h { z3.h }, p0, [sp, #1, mul vl]
74+
; CHECK-NEXT: st1h { z1.d }, p1, [sp, #6, mul vl]
75+
; CHECK-NEXT: ld1h { z3.h }, p0/z, [sp, #1, mul vl]
76+
; CHECK-NEXT: st1h { z3.h }, p0, [sp, #2, mul vl]
77+
; CHECK-NEXT: st1h { z1.d }, p1, [x8, #7, mul vl]
78+
; CHECK-NEXT: ld1h { z1.h }, p0/z, [sp, #2, mul vl]
79+
; CHECK-NEXT: fadda h2, p0, h2, z1.h
80+
; CHECK-NEXT: fmov s0, s2
81+
; CHECK-NEXT: addvl sp, sp, #3
82+
; CHECK-NEXT: ldr x29, [sp], #16
83+
; CHECK-NEXT: ret
84+
%res = call half @llvm.vector.reduce.fadd.nxv10f16(half %s, <vscale x 10 x half> %v)
85+
ret half %res
86+
}
87+
88+
define half @fadda_nxv12f16(<vscale x 12 x half> %v, half %s) {
89+
; CHECK-LABEL: fadda_nxv12f16:
90+
; CHECK: adrp x8, .LCPI5_0
91+
; CHECK-NEXT: add x8, x8, :lo12:.LCPI5_0
92+
; CHECK-NEXT: ptrue p0.s
93+
; CHECK-NEXT: uunpklo z1.s, z1.h
94+
; CHECK-NEXT: ld1rh { z3.s }, p0/z, [x8]
95+
; CHECK-NEXT: ptrue p0.h
96+
; CHECK-NEXT: fadda h2, p0, h2, z0.h
97+
; CHECK-NEXT: uzp1 z1.h, z1.h, z3.h
98+
; CHECK-NEXT: fadda h2, p0, h2, z1.h
99+
; CHECK-NEXT: fmov s0, s2
100+
; CHECK-NEXT: ret
101+
%res = call half @llvm.vector.reduce.fadd.nxv12f16(half %s, <vscale x 12 x half> %v)
102+
ret half %res
103+
}
104+
32105
define float @fadda_nxv2f32(float %init, <vscale x 2 x float> %a) {
33106
; CHECK-LABEL: fadda_nxv2f32:
34107
; CHECK: ptrue p0.d
@@ -233,6 +306,9 @@ define double @fminv_nxv2f64(<vscale x 2 x double> %a) {
233306
declare half @llvm.vector.reduce.fadd.nxv2f16(half, <vscale x 2 x half>)
234307
declare half @llvm.vector.reduce.fadd.nxv4f16(half, <vscale x 4 x half>)
235308
declare half @llvm.vector.reduce.fadd.nxv8f16(half, <vscale x 8 x half>)
309+
declare half @llvm.vector.reduce.fadd.nxv6f16(half, <vscale x 6 x half>)
310+
declare half @llvm.vector.reduce.fadd.nxv10f16(half, <vscale x 10 x half>)
311+
declare half @llvm.vector.reduce.fadd.nxv12f16(half, <vscale x 12 x half>)
236312
declare float @llvm.vector.reduce.fadd.nxv2f32(float, <vscale x 2 x float>)
237313
declare float @llvm.vector.reduce.fadd.nxv4f32(float, <vscale x 4 x float>)
238314
declare double @llvm.vector.reduce.fadd.nxv2f64(double, <vscale x 2 x double>)

llvm/test/CodeGen/RISCV/rvv/vreductions-fp-sdnode.ll

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2-
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+experimental-zvfh,+v -target-abi=ilp32d \
2+
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+experimental-zvfh,+v,+m -target-abi=ilp32d \
33
; RUN: -verify-machineinstrs < %s | FileCheck %s
4-
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+experimental-zvfh,+v -target-abi=lp64d \
4+
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+experimental-zvfh,+v,+m -target-abi=lp64d \
55
; RUN: -verify-machineinstrs < %s | FileCheck %s
66

77
declare half @llvm.vector.reduce.fadd.nxv1f16(half, <vscale x 1 x half>)
@@ -1048,3 +1048,101 @@ define float @vreduce_nsz_fadd_nxv1f32(<vscale x 1 x float> %v, float %s) {
10481048
%red = call reassoc nsz float @llvm.vector.reduce.fadd.nxv1f32(float %s, <vscale x 1 x float> %v)
10491049
ret float %red
10501050
}
1051+
1052+
; Test Widen VECREDUCE_SEQ_FADD
1053+
declare half @llvm.vector.reduce.fadd.nxv3f16(half, <vscale x 3 x half>)
1054+
1055+
define half @vreduce_ord_fadd_nxv3f16(<vscale x 3 x half> %v, half %s) {
1056+
; CHECK-LABEL: vreduce_ord_fadd_nxv3f16:
1057+
; CHECK: # %bb.0:
1058+
; CHECK-NEXT: csrr a0, vlenb
1059+
; CHECK-NEXT: srli a0, a0, 3
1060+
; CHECK-NEXT: slli a1, a0, 1
1061+
; CHECK-NEXT: add a1, a1, a0
1062+
; CHECK-NEXT: add a0, a1, a0
1063+
; CHECK-NEXT: fmv.h.x ft0, zero
1064+
; CHECK-NEXT: fneg.h ft0, ft0
1065+
; CHECK-NEXT: vsetvli a2, zero, e16, m1, ta, mu
1066+
; CHECK-NEXT: vfmv.v.f v9, ft0
1067+
; CHECK-NEXT: vsetvli zero, a0, e16, m1, tu, mu
1068+
; CHECK-NEXT: vslideup.vx v8, v9, a1
1069+
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, mu
1070+
; CHECK-NEXT: vfmv.s.f v9, fa0
1071+
; CHECK-NEXT: vsetvli a0, zero, e16, m1, ta, mu
1072+
; CHECK-NEXT: vfredosum.vs v8, v8, v9
1073+
; CHECK-NEXT: vfmv.f.s fa0, v8
1074+
; CHECK-NEXT: ret
1075+
%red = call half @llvm.vector.reduce.fadd.nxv3f16(half %s, <vscale x 3 x half> %v)
1076+
ret half %red
1077+
}
1078+
1079+
declare half @llvm.vector.reduce.fadd.nxv6f16(half, <vscale x 6 x half>)
1080+
1081+
define half @vreduce_ord_fadd_nxv6f16(<vscale x 6 x half> %v, half %s) {
1082+
; CHECK-LABEL: vreduce_ord_fadd_nxv6f16:
1083+
; CHECK: # %bb.0:
1084+
; CHECK-NEXT: csrr a0, vlenb
1085+
; CHECK-NEXT: srli a0, a0, 2
1086+
; CHECK-NEXT: add a1, a0, a0
1087+
; CHECK-NEXT: fmv.h.x ft0, zero
1088+
; CHECK-NEXT: fneg.h ft0, ft0
1089+
; CHECK-NEXT: vsetvli a2, zero, e16, m1, ta, mu
1090+
; CHECK-NEXT: vfmv.v.f v10, ft0
1091+
; CHECK-NEXT: vsetvli zero, a1, e16, m1, tu, mu
1092+
; CHECK-NEXT: vslideup.vx v9, v10, a0
1093+
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, mu
1094+
; CHECK-NEXT: vfmv.s.f v10, fa0
1095+
; CHECK-NEXT: vsetvli a0, zero, e16, m2, ta, mu
1096+
; CHECK-NEXT: vfredosum.vs v8, v8, v10
1097+
; CHECK-NEXT: vfmv.f.s fa0, v8
1098+
; CHECK-NEXT: ret
1099+
%red = call half @llvm.vector.reduce.fadd.nxv6f16(half %s, <vscale x 6 x half> %v)
1100+
ret half %red
1101+
}
1102+
1103+
declare half @llvm.vector.reduce.fadd.nxv10f16(half, <vscale x 10 x half>)
1104+
1105+
define half @vreduce_ord_fadd_nxv10f16(<vscale x 10 x half> %v, half %s) {
1106+
; CHECK-LABEL: vreduce_ord_fadd_nxv10f16:
1107+
; CHECK: # %bb.0:
1108+
; CHECK-NEXT: csrr a0, vlenb
1109+
; CHECK-NEXT: srli a0, a0, 2
1110+
; CHECK-NEXT: add a1, a0, a0
1111+
; CHECK-NEXT: fmv.h.x ft0, zero
1112+
; CHECK-NEXT: fneg.h ft0, ft0
1113+
; CHECK-NEXT: vsetvli a2, zero, e16, m1, ta, mu
1114+
; CHECK-NEXT: vfmv.v.f v12, ft0
1115+
; CHECK-NEXT: vsetvli zero, a1, e16, m1, tu, mu
1116+
; CHECK-NEXT: vslideup.vx v10, v12, a0
1117+
; CHECK-NEXT: vsetvli zero, a0, e16, m1, tu, mu
1118+
; CHECK-NEXT: vslideup.vi v11, v12, 0
1119+
; CHECK-NEXT: vsetvli zero, a1, e16, m1, tu, mu
1120+
; CHECK-NEXT: vslideup.vx v11, v12, a0
1121+
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, mu
1122+
; CHECK-NEXT: vfmv.s.f v12, fa0
1123+
; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, mu
1124+
; CHECK-NEXT: vfredosum.vs v8, v8, v12
1125+
; CHECK-NEXT: vfmv.f.s fa0, v8
1126+
; CHECK-NEXT: ret
1127+
%red = call half @llvm.vector.reduce.fadd.nxv10f16(half %s, <vscale x 10 x half> %v)
1128+
ret half %red
1129+
}
1130+
1131+
declare half @llvm.vector.reduce.fadd.nxv12f16(half, <vscale x 12 x half>)
1132+
1133+
define half @vreduce_ord_fadd_nxv12f16(<vscale x 12 x half> %v, half %s) {
1134+
; CHECK-LABEL: vreduce_ord_fadd_nxv12f16:
1135+
; CHECK: # %bb.0:
1136+
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, mu
1137+
; CHECK-NEXT: vfmv.s.f v12, fa0
1138+
; CHECK-NEXT: fmv.h.x ft0, zero
1139+
; CHECK-NEXT: fneg.h ft0, ft0
1140+
; CHECK-NEXT: vsetvli a0, zero, e16, m1, ta, mu
1141+
; CHECK-NEXT: vfmv.v.f v11, ft0
1142+
; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, mu
1143+
; CHECK-NEXT: vfredosum.vs v8, v8, v12
1144+
; CHECK-NEXT: vfmv.f.s fa0, v8
1145+
; CHECK-NEXT: ret
1146+
%red = call half @llvm.vector.reduce.fadd.nxv12f16(half %s, <vscale x 12 x half> %v)
1147+
ret half %red
1148+
}

0 commit comments

Comments
 (0)