Skip to content

Commit 50a32d2

Browse files
[LLVM][CodeGen][SVE] Add isel for bfloat unordered reductions.
The omissions are VECREDUCE_SEQ_* and MUL. The former goes down a different code path and the latter is generally unsupported across all element types.
1 parent bb531ff commit 50a32d2

File tree

4 files changed

+275
-15
lines changed

4 files changed

+275
-15
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ class VectorLegalizer {
188188
void PromoteSETCC(SDNode *Node, SmallVectorImpl<SDValue> &Results);
189189

190190
void PromoteSTRICT(SDNode *Node, SmallVectorImpl<SDValue> &Results);
191+
void PromoteVECREDUCE(SDNode *Node, SmallVectorImpl<SDValue> &Results);
191192

192193
public:
193194
VectorLegalizer(SelectionDAG& dag) :
@@ -500,20 +501,14 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
500501
case ISD::VECREDUCE_UMAX:
501502
case ISD::VECREDUCE_UMIN:
502503
case ISD::VECREDUCE_FADD:
503-
case ISD::VECREDUCE_FMUL:
504-
case ISD::VECTOR_FIND_LAST_ACTIVE:
505-
Action = TLI.getOperationAction(Node->getOpcode(),
506-
Node->getOperand(0).getValueType());
507-
break;
508504
case ISD::VECREDUCE_FMAX:
509-
case ISD::VECREDUCE_FMIN:
510505
case ISD::VECREDUCE_FMAXIMUM:
506+
case ISD::VECREDUCE_FMIN:
511507
case ISD::VECREDUCE_FMINIMUM:
508+
case ISD::VECREDUCE_FMUL:
509+
case ISD::VECTOR_FIND_LAST_ACTIVE:
512510
Action = TLI.getOperationAction(Node->getOpcode(),
513511
Node->getOperand(0).getValueType());
514-
// Defer non-vector results to LegalizeDAG.
515-
if (Action == TargetLowering::Promote)
516-
Action = TargetLowering::Legal;
517512
break;
518513
case ISD::VECREDUCE_SEQ_FADD:
519514
case ISD::VECREDUCE_SEQ_FMUL:
@@ -688,6 +683,22 @@ void VectorLegalizer::PromoteSTRICT(SDNode *Node,
688683
Results.push_back(Round.getValue(1));
689684
}
690685

686+
void VectorLegalizer::PromoteVECREDUCE(SDNode *Node,
687+
SmallVectorImpl<SDValue> &Results) {
688+
MVT OpVT = Node->getOperand(0).getSimpleValueType();
689+
assert(OpVT.isFloatingPoint() && "Expected floating point reduction!");
690+
MVT NewOpVT = TLI.getTypeToPromoteTo(Node->getOpcode(), OpVT);
691+
692+
SDLoc DL(Node);
693+
SDValue NewOp = DAG.getNode(ISD::FP_EXTEND, DL, NewOpVT, Node->getOperand(0));
694+
SDValue Rdx =
695+
DAG.getNode(Node->getOpcode(), DL, NewOpVT.getVectorElementType(), NewOp,
696+
Node->getFlags());
697+
SDValue Res = DAG.getNode(ISD::FP_ROUND, DL, Node->getValueType(0), Rdx,
698+
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
699+
Results.push_back(Res);
700+
}
701+
691702
void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
692703
// For a few operations there is a specific concept for promotion based on
693704
// the operand's type.
@@ -719,6 +730,13 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
719730
case ISD::STRICT_FMA:
720731
PromoteSTRICT(Node, Results);
721732
return;
733+
case ISD::VECREDUCE_FADD:
734+
case ISD::VECREDUCE_FMAX:
735+
case ISD::VECREDUCE_FMAXIMUM:
736+
case ISD::VECREDUCE_FMIN:
737+
case ISD::VECREDUCE_FMINIMUM:
738+
PromoteVECREDUCE(Node, Results);
739+
return;
722740
case ISD::FP_ROUND:
723741
case ISD::FP_EXTEND:
724742
// These operations are used to do promotion so they can't be promoted

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11412,13 +11412,9 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
1141211412
SDValue Op = Node->getOperand(0);
1141311413
EVT VT = Op.getValueType();
1141411414

11415-
if (VT.isScalableVector())
11416-
report_fatal_error(
11417-
"Expanding reductions for scalable vectors is undefined.");
11418-
1141911415
// Try to use a shuffle reduction for power of two vectors.
1142011416
if (VT.isPow2VectorType()) {
11421-
while (VT.getVectorNumElements() > 1) {
11417+
while (VT.getVectorElementCount().isKnownMultipleOf(2)) {
1142211418
EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
1142311419
if (!isOperationLegalOrCustom(BaseOpcode, HalfVT))
1142411420
break;
@@ -11427,9 +11423,18 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
1142711423
std::tie(Lo, Hi) = DAG.SplitVector(Op, dl);
1142811424
Op = DAG.getNode(BaseOpcode, dl, HalfVT, Lo, Hi, Node->getFlags());
1142911425
VT = HalfVT;
11426+
11427+
// Stop if splitting is enough to make the reduction legal.
11428+
if (isOperationLegalOrCustom(Node->getOpcode(), HalfVT))
11429+
return DAG.getNode(Node->getOpcode(), dl, Node->getValueType(0), Op,
11430+
Node->getFlags());
1143011431
}
1143111432
}
1143211433

11434+
if (VT.isScalableVector())
11435+
report_fatal_error(
11436+
"Expanding reductions for scalable vectors is undefined.");
11437+
1143311438
EVT EltVT = VT.getVectorElementType();
1143411439
unsigned NumElts = VT.getVectorNumElements();
1143511440

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1780,7 +1780,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
17801780

17811781
for (auto Opcode :
17821782
{ISD::FCEIL, ISD::FDIV, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
1783-
ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC, ISD::SETCC}) {
1783+
ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC, ISD::SETCC,
1784+
ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMAXIMUM,
1785+
ISD::VECREDUCE_FMIN, ISD::VECREDUCE_FMINIMUM}) {
17841786
setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
17851787
setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
17861788
setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s
3+
; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s
4+
5+
target triple = "aarch64-unknown-linux-gnu"
6+
7+
; FADDV
8+
9+
define bfloat @faddv_nxv2bf16(<vscale x 2 x bfloat> %a) {
10+
; CHECK-LABEL: faddv_nxv2bf16:
11+
; CHECK: // %bb.0:
12+
; CHECK-NEXT: lsl z0.s, z0.s, #16
13+
; CHECK-NEXT: ptrue p0.d
14+
; CHECK-NEXT: faddv s0, p0, z0.s
15+
; CHECK-NEXT: bfcvt h0, s0
16+
; CHECK-NEXT: ret
17+
%res = call fast bfloat @llvm.vector.reduce.fadd.nxv2bf16(bfloat zeroinitializer, <vscale x 2 x bfloat> %a)
18+
ret bfloat %res
19+
}
20+
21+
define bfloat @faddv_nxv4bf16(<vscale x 4 x bfloat> %a) {
22+
; CHECK-LABEL: faddv_nxv4bf16:
23+
; CHECK: // %bb.0:
24+
; CHECK-NEXT: lsl z0.s, z0.s, #16
25+
; CHECK-NEXT: ptrue p0.s
26+
; CHECK-NEXT: faddv s0, p0, z0.s
27+
; CHECK-NEXT: bfcvt h0, s0
28+
; CHECK-NEXT: ret
29+
%res = call fast bfloat @llvm.vector.reduce.fadd.nxv4bf16(bfloat zeroinitializer, <vscale x 4 x bfloat> %a)
30+
ret bfloat %res
31+
}
32+
33+
define bfloat @faddv_nxv8bf16(<vscale x 8 x bfloat> %a) {
34+
; CHECK-LABEL: faddv_nxv8bf16:
35+
; CHECK: // %bb.0:
36+
; CHECK-NEXT: uunpkhi z1.s, z0.h
37+
; CHECK-NEXT: uunpklo z0.s, z0.h
38+
; CHECK-NEXT: ptrue p0.s
39+
; CHECK-NEXT: lsl z1.s, z1.s, #16
40+
; CHECK-NEXT: lsl z0.s, z0.s, #16
41+
; CHECK-NEXT: fadd z0.s, z0.s, z1.s
42+
; CHECK-NEXT: faddv s0, p0, z0.s
43+
; CHECK-NEXT: bfcvt h0, s0
44+
; CHECK-NEXT: ret
45+
%res = call fast bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat zeroinitializer, <vscale x 8 x bfloat> %a)
46+
ret bfloat %res
47+
}
48+
49+
; FMAXNMV
50+
51+
define bfloat @fmaxv_nxv2bf16(<vscale x 2 x bfloat> %a) {
52+
; CHECK-LABEL: fmaxv_nxv2bf16:
53+
; CHECK: // %bb.0:
54+
; CHECK-NEXT: lsl z0.s, z0.s, #16
55+
; CHECK-NEXT: ptrue p0.d
56+
; CHECK-NEXT: fmaxnmv s0, p0, z0.s
57+
; CHECK-NEXT: bfcvt h0, s0
58+
; CHECK-NEXT: ret
59+
%res = call bfloat @llvm.vector.reduce.fmax.nxv2bf16(<vscale x 2 x bfloat> %a)
60+
ret bfloat %res
61+
}
62+
63+
define bfloat @fmaxv_nxv4bf16(<vscale x 4 x bfloat> %a) {
64+
; CHECK-LABEL: fmaxv_nxv4bf16:
65+
; CHECK: // %bb.0:
66+
; CHECK-NEXT: lsl z0.s, z0.s, #16
67+
; CHECK-NEXT: ptrue p0.s
68+
; CHECK-NEXT: fmaxnmv s0, p0, z0.s
69+
; CHECK-NEXT: bfcvt h0, s0
70+
; CHECK-NEXT: ret
71+
%res = call bfloat @llvm.vector.reduce.fmax.nxv4bf16(<vscale x 4 x bfloat> %a)
72+
ret bfloat %res
73+
}
74+
75+
define bfloat @fmaxv_nxv8bf16(<vscale x 8 x bfloat> %a) {
76+
; CHECK-LABEL: fmaxv_nxv8bf16:
77+
; CHECK: // %bb.0:
78+
; CHECK-NEXT: uunpkhi z1.s, z0.h
79+
; CHECK-NEXT: uunpklo z0.s, z0.h
80+
; CHECK-NEXT: ptrue p0.s
81+
; CHECK-NEXT: lsl z1.s, z1.s, #16
82+
; CHECK-NEXT: lsl z0.s, z0.s, #16
83+
; CHECK-NEXT: fmaxnm z0.s, p0/m, z0.s, z1.s
84+
; CHECK-NEXT: fmaxnmv s0, p0, z0.s
85+
; CHECK-NEXT: bfcvt h0, s0
86+
; CHECK-NEXT: ret
87+
%res = call bfloat @llvm.vector.reduce.fmax.nxv8bf16(<vscale x 8 x bfloat> %a)
88+
ret bfloat %res
89+
}
90+
91+
; FMINNMV
92+
93+
define bfloat @fminv_nxv2bf16(<vscale x 2 x bfloat> %a) {
94+
; CHECK-LABEL: fminv_nxv2bf16:
95+
; CHECK: // %bb.0:
96+
; CHECK-NEXT: lsl z0.s, z0.s, #16
97+
; CHECK-NEXT: ptrue p0.d
98+
; CHECK-NEXT: fminnmv s0, p0, z0.s
99+
; CHECK-NEXT: bfcvt h0, s0
100+
; CHECK-NEXT: ret
101+
%res = call bfloat @llvm.vector.reduce.fmin.nxv2bf16(<vscale x 2 x bfloat> %a)
102+
ret bfloat %res
103+
}
104+
105+
define bfloat @fminv_nxv4bf16(<vscale x 4 x bfloat> %a) {
106+
; CHECK-LABEL: fminv_nxv4bf16:
107+
; CHECK: // %bb.0:
108+
; CHECK-NEXT: lsl z0.s, z0.s, #16
109+
; CHECK-NEXT: ptrue p0.s
110+
; CHECK-NEXT: fminnmv s0, p0, z0.s
111+
; CHECK-NEXT: bfcvt h0, s0
112+
; CHECK-NEXT: ret
113+
%res = call bfloat @llvm.vector.reduce.fmin.nxv4bf16(<vscale x 4 x bfloat> %a)
114+
ret bfloat %res
115+
}
116+
117+
define bfloat @fminv_nxv8bf16(<vscale x 8 x bfloat> %a) {
118+
; CHECK-LABEL: fminv_nxv8bf16:
119+
; CHECK: // %bb.0:
120+
; CHECK-NEXT: uunpkhi z1.s, z0.h
121+
; CHECK-NEXT: uunpklo z0.s, z0.h
122+
; CHECK-NEXT: ptrue p0.s
123+
; CHECK-NEXT: lsl z1.s, z1.s, #16
124+
; CHECK-NEXT: lsl z0.s, z0.s, #16
125+
; CHECK-NEXT: fminnm z0.s, p0/m, z0.s, z1.s
126+
; CHECK-NEXT: fminnmv s0, p0, z0.s
127+
; CHECK-NEXT: bfcvt h0, s0
128+
; CHECK-NEXT: ret
129+
%res = call bfloat @llvm.vector.reduce.fmin.nxv8bf16(<vscale x 8 x bfloat> %a)
130+
ret bfloat %res
131+
}
132+
133+
; FMAXV
134+
135+
define bfloat @fmaximumv_nxv2bf16(<vscale x 2 x bfloat> %a) {
136+
; CHECK-LABEL: fmaximumv_nxv2bf16:
137+
; CHECK: // %bb.0:
138+
; CHECK-NEXT: lsl z0.s, z0.s, #16
139+
; CHECK-NEXT: ptrue p0.d
140+
; CHECK-NEXT: fmaxv s0, p0, z0.s
141+
; CHECK-NEXT: bfcvt h0, s0
142+
; CHECK-NEXT: ret
143+
%res = call bfloat @llvm.vector.reduce.fmaximum.nxv2bf16(<vscale x 2 x bfloat> %a)
144+
ret bfloat %res
145+
}
146+
147+
define bfloat @fmaximumv_nxv4bf16(<vscale x 4 x bfloat> %a) {
148+
; CHECK-LABEL: fmaximumv_nxv4bf16:
149+
; CHECK: // %bb.0:
150+
; CHECK-NEXT: lsl z0.s, z0.s, #16
151+
; CHECK-NEXT: ptrue p0.s
152+
; CHECK-NEXT: fmaxv s0, p0, z0.s
153+
; CHECK-NEXT: bfcvt h0, s0
154+
; CHECK-NEXT: ret
155+
%res = call bfloat @llvm.vector.reduce.fmaximum.nxv4bf16(<vscale x 4 x bfloat> %a)
156+
ret bfloat %res
157+
}
158+
159+
define bfloat @fmaximumv_nxv8bf16(<vscale x 8 x bfloat> %a) {
160+
; CHECK-LABEL: fmaximumv_nxv8bf16:
161+
; CHECK: // %bb.0:
162+
; CHECK-NEXT: uunpkhi z1.s, z0.h
163+
; CHECK-NEXT: uunpklo z0.s, z0.h
164+
; CHECK-NEXT: ptrue p0.s
165+
; CHECK-NEXT: lsl z1.s, z1.s, #16
166+
; CHECK-NEXT: lsl z0.s, z0.s, #16
167+
; CHECK-NEXT: fmax z0.s, p0/m, z0.s, z1.s
168+
; CHECK-NEXT: fmaxv s0, p0, z0.s
169+
; CHECK-NEXT: bfcvt h0, s0
170+
; CHECK-NEXT: ret
171+
%res = call bfloat @llvm.vector.reduce.fmaximum.nxv8bf16(<vscale x 8 x bfloat> %a)
172+
ret bfloat %res
173+
}
174+
175+
; FMINV
176+
177+
define bfloat @fminimumv_nxv2bf16(<vscale x 2 x bfloat> %a) {
178+
; CHECK-LABEL: fminimumv_nxv2bf16:
179+
; CHECK: // %bb.0:
180+
; CHECK-NEXT: lsl z0.s, z0.s, #16
181+
; CHECK-NEXT: ptrue p0.d
182+
; CHECK-NEXT: fminv s0, p0, z0.s
183+
; CHECK-NEXT: bfcvt h0, s0
184+
; CHECK-NEXT: ret
185+
%res = call bfloat @llvm.vector.reduce.fminimum.nxv2bf16(<vscale x 2 x bfloat> %a)
186+
ret bfloat %res
187+
}
188+
189+
define bfloat @fminimumv_nxv4bf16(<vscale x 4 x bfloat> %a) {
190+
; CHECK-LABEL: fminimumv_nxv4bf16:
191+
; CHECK: // %bb.0:
192+
; CHECK-NEXT: lsl z0.s, z0.s, #16
193+
; CHECK-NEXT: ptrue p0.s
194+
; CHECK-NEXT: fminv s0, p0, z0.s
195+
; CHECK-NEXT: bfcvt h0, s0
196+
; CHECK-NEXT: ret
197+
%res = call bfloat @llvm.vector.reduce.fminimum.nxv4bf16(<vscale x 4 x bfloat> %a)
198+
ret bfloat %res
199+
}
200+
201+
define bfloat @fminimumv_nxv8bf16(<vscale x 8 x bfloat> %a) {
202+
; CHECK-LABEL: fminimumv_nxv8bf16:
203+
; CHECK: // %bb.0:
204+
; CHECK-NEXT: uunpkhi z1.s, z0.h
205+
; CHECK-NEXT: uunpklo z0.s, z0.h
206+
; CHECK-NEXT: ptrue p0.s
207+
; CHECK-NEXT: lsl z1.s, z1.s, #16
208+
; CHECK-NEXT: lsl z0.s, z0.s, #16
209+
; CHECK-NEXT: fmin z0.s, p0/m, z0.s, z1.s
210+
; CHECK-NEXT: fminv s0, p0, z0.s
211+
; CHECK-NEXT: bfcvt h0, s0
212+
; CHECK-NEXT: ret
213+
%res = call bfloat @llvm.vector.reduce.fminimum.nxv8bf16(<vscale x 8 x bfloat> %a)
214+
ret bfloat %res
215+
}
216+
217+
declare bfloat @llvm.vector.reduce.fadd.nxv2bf16(bfloat, <vscale x 2 x bfloat>)
218+
declare bfloat @llvm.vector.reduce.fadd.nxv4bf16(bfloat, <vscale x 4 x bfloat>)
219+
declare bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat, <vscale x 8 x bfloat>)
220+
221+
declare bfloat @llvm.vector.reduce.fmax.nxv2bf16(<vscale x 2 x bfloat>)
222+
declare bfloat @llvm.vector.reduce.fmax.nxv4bf16(<vscale x 4 x bfloat>)
223+
declare bfloat @llvm.vector.reduce.fmax.nxv8bf16(<vscale x 8 x bfloat>)
224+
225+
declare bfloat @llvm.vector.reduce.fmin.nxv2bf16(<vscale x 2 x bfloat>)
226+
declare bfloat @llvm.vector.reduce.fmin.nxv4bf16(<vscale x 4 x bfloat>)
227+
declare bfloat @llvm.vector.reduce.fmin.nxv8bf16(<vscale x 8 x bfloat>)
228+
229+
declare bfloat @llvm.vector.reduce.fmaximum.nxv2bf16(<vscale x 2 x bfloat>)
230+
declare bfloat @llvm.vector.reduce.fmaximum.nxv4bf16(<vscale x 4 x bfloat>)
231+
declare bfloat @llvm.vector.reduce.fmaximum.nxv8bf16(<vscale x 8 x bfloat>)
232+
233+
declare bfloat @llvm.vector.reduce.fminimum.nxv2bf16(<vscale x 2 x bfloat>)
234+
declare bfloat @llvm.vector.reduce.fminimum.nxv4bf16(<vscale x 4 x bfloat>)
235+
declare bfloat @llvm.vector.reduce.fminimum.nxv8bf16(<vscale x 8 x bfloat>)

0 commit comments

Comments
 (0)