Skip to content

Commit 9dc5f72

Browse files
Rename function. Set FP_ROUND TRUNC flag when safe and add associated tests.
1 parent 50a32d2 commit 9dc5f72

File tree

3 files changed

+62
-9
lines changed

3 files changed

+62
-9
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,12 @@ class VectorLegalizer {
188188
void PromoteSETCC(SDNode *Node, SmallVectorImpl<SDValue> &Results);
189189

190190
void PromoteSTRICT(SDNode *Node, SmallVectorImpl<SDValue> &Results);
191-
void PromoteVECREDUCE(SDNode *Node, SmallVectorImpl<SDValue> &Results);
191+
192+
/// Calculate the reduction using a type of higher precision and round the
193+
/// result to match the original type. Set NonArithmetic when the reduction
194+
/// (e.g. min/max) cannot use the extra precision, so rounding is lossless.
195+
void PromoteFloatVECREDUCE(SDNode *Node, SmallVectorImpl<SDValue> &Results,
196+
bool NonArithmetic);
192197

193198
public:
194199
VectorLegalizer(SelectionDAG& dag) :
@@ -683,8 +688,9 @@ void VectorLegalizer::PromoteSTRICT(SDNode *Node,
683688
Results.push_back(Round.getValue(1));
684689
}
685690

686-
void VectorLegalizer::PromoteVECREDUCE(SDNode *Node,
687-
SmallVectorImpl<SDValue> &Results) {
691+
void VectorLegalizer::PromoteFloatVECREDUCE(SDNode *Node,
692+
SmallVectorImpl<SDValue> &Results,
693+
bool NonArithmetic) {
688694
MVT OpVT = Node->getOperand(0).getSimpleValueType();
689695
assert(OpVT.isFloatingPoint() && "Expected floating point reduction!");
690696
MVT NewOpVT = TLI.getTypeToPromoteTo(Node->getOpcode(), OpVT);
@@ -694,8 +700,9 @@ void VectorLegalizer::PromoteVECREDUCE(SDNode *Node,
694700
SDValue Rdx =
695701
DAG.getNode(Node->getOpcode(), DL, NewOpVT.getVectorElementType(), NewOp,
696702
Node->getFlags());
697-
SDValue Res = DAG.getNode(ISD::FP_ROUND, DL, Node->getValueType(0), Rdx,
698-
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
703+
SDValue Res =
704+
DAG.getNode(ISD::FP_ROUND, DL, Node->getValueType(0), Rdx,
705+
DAG.getIntPtrConstant(NonArithmetic, DL, /*isTarget=*/true));
699706
Results.push_back(Res);
700707
}
701708

@@ -731,11 +738,13 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
731738
PromoteSTRICT(Node, Results);
732739
return;
733740
case ISD::VECREDUCE_FADD:
741+
PromoteFloatVECREDUCE(Node, Results, /*NonArithmetic=*/false);
742+
return;
734743
case ISD::VECREDUCE_FMAX:
735744
case ISD::VECREDUCE_FMAXIMUM:
736745
case ISD::VECREDUCE_FMIN:
737746
case ISD::VECREDUCE_FMINIMUM:
738-
PromoteVECREDUCE(Node, Results);
747+
PromoteFloatVECREDUCE(Node, Results, /*NonArithmetic=*/true);
739748
return;
740749
case ISD::FP_ROUND:
741750
case ISD::FP_EXTEND:

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11432,7 +11432,7 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
1143211432
}
1143311433

1143411434
if (VT.isScalableVector())
11435-
report_fatal_error(
11435+
reportFatalInternalError(
1143611436
"Expanding reductions for scalable vectors is undefined.");
1143711437

1143811438
EVT EltVT = VT.getVectorElementType();

llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2-
; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s
3-
; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s
2+
; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s -check-prefixes=CHECK,SVE
3+
; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s -check-prefixes=CHECK,SME
44

55
target triple = "aarch64-unknown-linux-gnu"
66

@@ -214,6 +214,50 @@ define bfloat @fminimumv_nxv8bf16(<vscale x 8 x bfloat> %a) {
214214
ret bfloat %res
215215
}
216216

217+
; The reduction is performed at a higher precision. Because add operations
218+
; can utilise that precision, the result must be rounded to bfloat even if
219+
; it's then extended back to float.
220+
define float @promoted_fadd(<vscale x 4 x bfloat> %a) {
221+
; SVE-LABEL: promoted_fadd:
222+
; SVE: // %bb.0:
223+
; SVE-NEXT: lsl z0.s, z0.s, #16
224+
; SVE-NEXT: ptrue p0.s
225+
; SVE-NEXT: faddv s0, p0, z0.s
226+
; SVE-NEXT: bfcvt h0, s0
227+
; SVE-NEXT: shll v0.4s, v0.4h, #16
228+
; SVE-NEXT: // kill: def $s0 killed $s0 killed $q0
229+
; SVE-NEXT: ret
230+
;
231+
; SME-LABEL: promoted_fadd:
232+
; SME: // %bb.0:
233+
; SME-NEXT: lsl z0.s, z0.s, #16
234+
; SME-NEXT: ptrue p0.s
235+
; SME-NEXT: faddv s0, p0, z0.s
236+
; SME-NEXT: bfcvt h0, s0
237+
; SME-NEXT: fmov w8, s0
238+
; SME-NEXT: lsl w8, w8, #16
239+
; SME-NEXT: fmov s0, w8
240+
; SME-NEXT: ret
241+
%rdx = call fast bfloat @llvm.vector.reduce.fadd.nxv4bf16(bfloat zeroinitializer, <vscale x 4 x bfloat> %a)
242+
%res = fpext bfloat %rdx to float
243+
ret float %res
244+
}
245+
246+
; The reduction is performed at a higher precision. Because min/max operations
247+
; don't utilise that precision, the result can be used without rounding.
248+
define float @promoted_fmax(<vscale x 4 x bfloat> %a) {
249+
; CHECK-LABEL: promoted_fmax:
250+
; CHECK: // %bb.0:
251+
; CHECK-NEXT: lsl z0.s, z0.s, #16
252+
; CHECK-NEXT: ptrue p0.s
253+
; CHECK-NEXT: fmaxnmv s0, p0, z0.s
254+
; CHECK-NEXT: // kill: def $s0 killed $s0 killed $z0
255+
; CHECK-NEXT: ret
256+
%rdx = call bfloat @llvm.vector.reduce.fmax.nxv4bf16(<vscale x 4 x bfloat> %a)
257+
%res = fpext bfloat %rdx to float
258+
ret float %res
259+
}
260+
217261
declare bfloat @llvm.vector.reduce.fadd.nxv2bf16(bfloat, <vscale x 2 x bfloat>)
218262
declare bfloat @llvm.vector.reduce.fadd.nxv4bf16(bfloat, <vscale x 4 x bfloat>)
219263
declare bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat, <vscale x 8 x bfloat>)

0 commit comments

Comments
 (0)