Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ class VectorLegalizer {
void PromoteSETCC(SDNode *Node, SmallVectorImpl<SDValue> &Results);

void PromoteSTRICT(SDNode *Node, SmallVectorImpl<SDValue> &Results);
void PromoteVECREDUCE(SDNode *Node, SmallVectorImpl<SDValue> &Results);

public:
VectorLegalizer(SelectionDAG& dag) :
Expand Down Expand Up @@ -500,20 +501,14 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_UMIN:
case ISD::VECREDUCE_FADD:
case ISD::VECREDUCE_FMUL:
case ISD::VECTOR_FIND_LAST_ACTIVE:
Action = TLI.getOperationAction(Node->getOpcode(),
Node->getOperand(0).getValueType());
break;
case ISD::VECREDUCE_FMAX:
case ISD::VECREDUCE_FMIN:
case ISD::VECREDUCE_FMAXIMUM:
case ISD::VECREDUCE_FMIN:
case ISD::VECREDUCE_FMINIMUM:
case ISD::VECREDUCE_FMUL:
case ISD::VECTOR_FIND_LAST_ACTIVE:
Action = TLI.getOperationAction(Node->getOpcode(),
Node->getOperand(0).getValueType());
// Defer non-vector results to LegalizeDAG.
if (Action == TargetLowering::Promote)
Action = TargetLowering::Legal;
break;
case ISD::VECREDUCE_SEQ_FADD:
case ISD::VECREDUCE_SEQ_FMUL:
Expand Down Expand Up @@ -688,6 +683,22 @@ void VectorLegalizer::PromoteSTRICT(SDNode *Node,
Results.push_back(Round.getValue(1));
}

void VectorLegalizer::PromoteVECREDUCE(SDNode *Node,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name doesn't clearly indicate that this is for FP operations.

SmallVectorImpl<SDValue> &Results) {
MVT OpVT = Node->getOperand(0).getSimpleValueType();
assert(OpVT.isFloatingPoint() && "Expected floating point reduction!");
MVT NewOpVT = TLI.getTypeToPromoteTo(Node->getOpcode(), OpVT);

SDLoc DL(Node);
SDValue NewOp = DAG.getNode(ISD::FP_EXTEND, DL, NewOpVT, Node->getOperand(0));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about the strictfp case?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does strictfp apply here? There are no STRICT_VECREDUCE_ nodes. There are the ordered VECREDUCE_SEQ_ nodes, but they go down a different path so are not covered by this PR.

Copy link
Contributor

@arsenm arsenm Jun 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. This pushes on the broader issue with the current strictfp strategy where we need to duplicate all possible FP intrinsics and we're missing them here.

So for these non-strict intrinsics it's not an issue; the issue is that we don't have strict versions of them.

SDValue Rdx =
DAG.getNode(Node->getOpcode(), DL, NewOpVT.getVectorElementType(), NewOp,
Node->getFlags());
SDValue Res = DAG.getNode(ISD::FP_ROUND, DL, Node->getValueType(0), Rdx,
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use 1 here? We are converting back so it should be exactly representable, but the other cases seem to not do this

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think we can for MIN/MAX but not for ADD because those cases might legitimately round to infinity?

Results.push_back(Res);
}

void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
// For a few operations there is a specific concept for promotion based on
// the operand's type.
Expand Down Expand Up @@ -719,6 +730,13 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::STRICT_FMA:
PromoteSTRICT(Node, Results);
return;
case ISD::VECREDUCE_FADD:
case ISD::VECREDUCE_FMAX:
case ISD::VECREDUCE_FMAXIMUM:
case ISD::VECREDUCE_FMIN:
case ISD::VECREDUCE_FMINIMUM:
PromoteVECREDUCE(Node, Results);
return;
case ISD::FP_ROUND:
case ISD::FP_EXTEND:
// These operations are used to do promotion so they can't be promoted
Expand Down
15 changes: 10 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11412,13 +11412,9 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
SDValue Op = Node->getOperand(0);
EVT VT = Op.getValueType();

if (VT.isScalableVector())
report_fatal_error(
"Expanding reductions for scalable vectors is undefined.");

// Try to use a shuffle reduction for power of two vectors.
if (VT.isPow2VectorType()) {
while (VT.getVectorNumElements() > 1) {
while (VT.getVectorElementCount().isKnownMultipleOf(2)) {
EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
if (!isOperationLegalOrCustom(BaseOpcode, HalfVT))
break;
Expand All @@ -11427,9 +11423,18 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
std::tie(Lo, Hi) = DAG.SplitVector(Op, dl);
Op = DAG.getNode(BaseOpcode, dl, HalfVT, Lo, Hi, Node->getFlags());
VT = HalfVT;

// Stop if splitting is enough to make the reduction legal.
if (isOperationLegalOrCustom(Node->getOpcode(), HalfVT))
return DAG.getNode(Node->getOpcode(), dl, Node->getValueType(0), Op,
Node->getFlags());
}
}

if (VT.isScalableVector())
report_fatal_error(
"Expanding reductions for scalable vectors is undefined.");
Copy link
Contributor

@arsenm arsenm Jun 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (VT.isScalableVector())
report_fatal_error(
"Expanding reductions for scalable vectors is undefined.");
if (VT.isScalableVector()) {
reportFatalInternalError(
"expanding reductions for scalable vectors is undefined");
}


EVT EltVT = VT.getVectorElementType();
unsigned NumElts = VT.getVectorNumElements();

Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1780,7 +1780,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,

for (auto Opcode :
{ISD::FCEIL, ISD::FDIV, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC, ISD::SETCC}) {
ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC, ISD::SETCC,
ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMAXIMUM,
ISD::VECREDUCE_FMIN, ISD::VECREDUCE_FMINIMUM}) {
setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
Expand Down
235 changes: 235 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s
; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s

target triple = "aarch64-unknown-linux-gnu"

; FADDV

; Fast-math llvm.vector.reduce.fadd over <vscale x 2 x bfloat> with a zero
; start value. The expected code shifts every 32-bit lane left by 16
; (presumably the bf16 -> f32 widening), reduces with FADDV on .s lanes under
; an all-true .d predicate, and converts the scalar back with BFCVT.
define bfloat @faddv_nxv2bf16(<vscale x 2 x bfloat> %a) {
; CHECK-LABEL: faddv_nxv2bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: lsl z0.s, z0.s, #16
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: faddv s0, p0, z0.s
; CHECK-NEXT: bfcvt h0, s0
; CHECK-NEXT: ret
%res = call fast bfloat @llvm.vector.reduce.fadd.nxv2bf16(bfloat zeroinitializer, <vscale x 2 x bfloat> %a)
ret bfloat %res
}

; Fast-math llvm.vector.reduce.fadd over <vscale x 4 x bfloat> with a zero
; start value. Same expected sequence as the nxv2 variant, except that the
; governing predicate is an all-true .s predicate: lane shift by 16, FADDV on
; .s lanes, then BFCVT on the scalar result.
define bfloat @faddv_nxv4bf16(<vscale x 4 x bfloat> %a) {
; CHECK-LABEL: faddv_nxv4bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: lsl z0.s, z0.s, #16
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: faddv s0, p0, z0.s
; CHECK-NEXT: bfcvt h0, s0
; CHECK-NEXT: ret
%res = call fast bfloat @llvm.vector.reduce.fadd.nxv4bf16(bfloat zeroinitializer, <vscale x 4 x bfloat> %a)
ret bfloat %res
}

; Fast-math llvm.vector.reduce.fadd over <vscale x 8 x bfloat> with a zero
; start value. The expected code unpacks the input into high and low halves
; (UUNPKHI/UUNPKLO to .s lanes), shifts both left by 16, combines the halves
; with one vector FADD, and only then reduces with FADDV before the final
; BFCVT of the scalar.
define bfloat @faddv_nxv8bf16(<vscale x 8 x bfloat> %a) {
; CHECK-LABEL: faddv_nxv8bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: uunpkhi z1.s, z0.h
; CHECK-NEXT: uunpklo z0.s, z0.h
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: lsl z1.s, z1.s, #16
; CHECK-NEXT: lsl z0.s, z0.s, #16
; CHECK-NEXT: fadd z0.s, z0.s, z1.s
; CHECK-NEXT: faddv s0, p0, z0.s
; CHECK-NEXT: bfcvt h0, s0
; CHECK-NEXT: ret
%res = call fast bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat zeroinitializer, <vscale x 8 x bfloat> %a)
ret bfloat %res
}

; FMAXNMV

; llvm.vector.reduce.fmax over <vscale x 2 x bfloat>. The expected code shifts
; every 32-bit lane left by 16 (presumably the bf16 -> f32 widening), reduces
; with FMAXNMV on .s lanes under an all-true .d predicate, and converts the
; scalar back with BFCVT.
define bfloat @fmaxv_nxv2bf16(<vscale x 2 x bfloat> %a) {
; CHECK-LABEL: fmaxv_nxv2bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: lsl z0.s, z0.s, #16
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: fmaxnmv s0, p0, z0.s
; CHECK-NEXT: bfcvt h0, s0
; CHECK-NEXT: ret
%res = call bfloat @llvm.vector.reduce.fmax.nxv2bf16(<vscale x 2 x bfloat> %a)
ret bfloat %res
}

; llvm.vector.reduce.fmax over <vscale x 4 x bfloat>. Same expected sequence
; as the nxv2 variant, except that the governing predicate is an all-true .s
; predicate: lane shift by 16, FMAXNMV on .s lanes, then BFCVT on the scalar.
define bfloat @fmaxv_nxv4bf16(<vscale x 4 x bfloat> %a) {
; CHECK-LABEL: fmaxv_nxv4bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: lsl z0.s, z0.s, #16
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: fmaxnmv s0, p0, z0.s
; CHECK-NEXT: bfcvt h0, s0
; CHECK-NEXT: ret
%res = call bfloat @llvm.vector.reduce.fmax.nxv4bf16(<vscale x 4 x bfloat> %a)
ret bfloat %res
}

; llvm.vector.reduce.fmax over <vscale x 8 x bfloat>. The expected code
; unpacks the input into high and low halves (UUNPKHI/UUNPKLO to .s lanes),
; shifts both left by 16, combines the halves with one predicated vector
; FMAXNM, and only then reduces with FMAXNMV before the final BFCVT.
define bfloat @fmaxv_nxv8bf16(<vscale x 8 x bfloat> %a) {
; CHECK-LABEL: fmaxv_nxv8bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: uunpkhi z1.s, z0.h
; CHECK-NEXT: uunpklo z0.s, z0.h
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: lsl z1.s, z1.s, #16
; CHECK-NEXT: lsl z0.s, z0.s, #16
; CHECK-NEXT: fmaxnm z0.s, p0/m, z0.s, z1.s
; CHECK-NEXT: fmaxnmv s0, p0, z0.s
; CHECK-NEXT: bfcvt h0, s0
; CHECK-NEXT: ret
%res = call bfloat @llvm.vector.reduce.fmax.nxv8bf16(<vscale x 8 x bfloat> %a)
ret bfloat %res
}

; FMINNMV

; llvm.vector.reduce.fmin over <vscale x 2 x bfloat>. The expected code shifts
; every 32-bit lane left by 16 (presumably the bf16 -> f32 widening), reduces
; with FMINNMV on .s lanes under an all-true .d predicate, and converts the
; scalar back with BFCVT.
define bfloat @fminv_nxv2bf16(<vscale x 2 x bfloat> %a) {
; CHECK-LABEL: fminv_nxv2bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: lsl z0.s, z0.s, #16
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: fminnmv s0, p0, z0.s
; CHECK-NEXT: bfcvt h0, s0
; CHECK-NEXT: ret
%res = call bfloat @llvm.vector.reduce.fmin.nxv2bf16(<vscale x 2 x bfloat> %a)
ret bfloat %res
}

; llvm.vector.reduce.fmin over <vscale x 4 x bfloat>. Same expected sequence
; as the nxv2 variant, except that the governing predicate is an all-true .s
; predicate: lane shift by 16, FMINNMV on .s lanes, then BFCVT on the scalar.
define bfloat @fminv_nxv4bf16(<vscale x 4 x bfloat> %a) {
; CHECK-LABEL: fminv_nxv4bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: lsl z0.s, z0.s, #16
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: fminnmv s0, p0, z0.s
; CHECK-NEXT: bfcvt h0, s0
; CHECK-NEXT: ret
%res = call bfloat @llvm.vector.reduce.fmin.nxv4bf16(<vscale x 4 x bfloat> %a)
ret bfloat %res
}

; llvm.vector.reduce.fmin over <vscale x 8 x bfloat>. The expected code
; unpacks the input into high and low halves (UUNPKHI/UUNPKLO to .s lanes),
; shifts both left by 16, combines the halves with one predicated vector
; FMINNM, and only then reduces with FMINNMV before the final BFCVT.
define bfloat @fminv_nxv8bf16(<vscale x 8 x bfloat> %a) {
; CHECK-LABEL: fminv_nxv8bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: uunpkhi z1.s, z0.h
; CHECK-NEXT: uunpklo z0.s, z0.h
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: lsl z1.s, z1.s, #16
; CHECK-NEXT: lsl z0.s, z0.s, #16
; CHECK-NEXT: fminnm z0.s, p0/m, z0.s, z1.s
; CHECK-NEXT: fminnmv s0, p0, z0.s
; CHECK-NEXT: bfcvt h0, s0
; CHECK-NEXT: ret
%res = call bfloat @llvm.vector.reduce.fmin.nxv8bf16(<vscale x 8 x bfloat> %a)
ret bfloat %res
}

; FMAXV

; llvm.vector.reduce.fmaximum over <vscale x 2 x bfloat>. The expected code
; shifts every 32-bit lane left by 16 (presumably the bf16 -> f32 widening),
; reduces with FMAXV (not FMAXNMV) on .s lanes under an all-true .d predicate,
; and converts the scalar back with BFCVT.
define bfloat @fmaximumv_nxv2bf16(<vscale x 2 x bfloat> %a) {
; CHECK-LABEL: fmaximumv_nxv2bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: lsl z0.s, z0.s, #16
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: fmaxv s0, p0, z0.s
; CHECK-NEXT: bfcvt h0, s0
; CHECK-NEXT: ret
%res = call bfloat @llvm.vector.reduce.fmaximum.nxv2bf16(<vscale x 2 x bfloat> %a)
ret bfloat %res
}

; llvm.vector.reduce.fmaximum over <vscale x 4 x bfloat>. Same expected
; sequence as the nxv2 variant, except that the governing predicate is an
; all-true .s predicate: lane shift by 16, FMAXV on .s lanes, then BFCVT.
define bfloat @fmaximumv_nxv4bf16(<vscale x 4 x bfloat> %a) {
; CHECK-LABEL: fmaximumv_nxv4bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: lsl z0.s, z0.s, #16
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: fmaxv s0, p0, z0.s
; CHECK-NEXT: bfcvt h0, s0
; CHECK-NEXT: ret
%res = call bfloat @llvm.vector.reduce.fmaximum.nxv4bf16(<vscale x 4 x bfloat> %a)
ret bfloat %res
}

; llvm.vector.reduce.fmaximum over <vscale x 8 x bfloat>. The expected code
; unpacks the input into high and low halves (UUNPKHI/UUNPKLO to .s lanes),
; shifts both left by 16, combines the halves with one predicated vector FMAX,
; and only then reduces with FMAXV before the final BFCVT.
define bfloat @fmaximumv_nxv8bf16(<vscale x 8 x bfloat> %a) {
; CHECK-LABEL: fmaximumv_nxv8bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: uunpkhi z1.s, z0.h
; CHECK-NEXT: uunpklo z0.s, z0.h
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: lsl z1.s, z1.s, #16
; CHECK-NEXT: lsl z0.s, z0.s, #16
; CHECK-NEXT: fmax z0.s, p0/m, z0.s, z1.s
; CHECK-NEXT: fmaxv s0, p0, z0.s
; CHECK-NEXT: bfcvt h0, s0
; CHECK-NEXT: ret
%res = call bfloat @llvm.vector.reduce.fmaximum.nxv8bf16(<vscale x 8 x bfloat> %a)
ret bfloat %res
}

; FMINV

; llvm.vector.reduce.fminimum over <vscale x 2 x bfloat>. The expected code
; shifts every 32-bit lane left by 16 (presumably the bf16 -> f32 widening),
; reduces with FMINV (not FMINNMV) on .s lanes under an all-true .d predicate,
; and converts the scalar back with BFCVT.
define bfloat @fminimumv_nxv2bf16(<vscale x 2 x bfloat> %a) {
; CHECK-LABEL: fminimumv_nxv2bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: lsl z0.s, z0.s, #16
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: fminv s0, p0, z0.s
; CHECK-NEXT: bfcvt h0, s0
; CHECK-NEXT: ret
%res = call bfloat @llvm.vector.reduce.fminimum.nxv2bf16(<vscale x 2 x bfloat> %a)
ret bfloat %res
}

; llvm.vector.reduce.fminimum over <vscale x 4 x bfloat>. Same expected
; sequence as the nxv2 variant, except that the governing predicate is an
; all-true .s predicate: lane shift by 16, FMINV on .s lanes, then BFCVT.
define bfloat @fminimumv_nxv4bf16(<vscale x 4 x bfloat> %a) {
; CHECK-LABEL: fminimumv_nxv4bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: lsl z0.s, z0.s, #16
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: fminv s0, p0, z0.s
; CHECK-NEXT: bfcvt h0, s0
; CHECK-NEXT: ret
%res = call bfloat @llvm.vector.reduce.fminimum.nxv4bf16(<vscale x 4 x bfloat> %a)
ret bfloat %res
}

; llvm.vector.reduce.fminimum over <vscale x 8 x bfloat>. The expected code
; unpacks the input into high and low halves (UUNPKHI/UUNPKLO to .s lanes),
; shifts both left by 16, combines the halves with one predicated vector FMIN,
; and only then reduces with FMINV before the final BFCVT.
define bfloat @fminimumv_nxv8bf16(<vscale x 8 x bfloat> %a) {
; CHECK-LABEL: fminimumv_nxv8bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: uunpkhi z1.s, z0.h
; CHECK-NEXT: uunpklo z0.s, z0.h
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: lsl z1.s, z1.s, #16
; CHECK-NEXT: lsl z0.s, z0.s, #16
; CHECK-NEXT: fmin z0.s, p0/m, z0.s, z1.s
; CHECK-NEXT: fminv s0, p0, z0.s
; CHECK-NEXT: bfcvt h0, s0
; CHECK-NEXT: ret
%res = call bfloat @llvm.vector.reduce.fminimum.nxv8bf16(<vscale x 8 x bfloat> %a)
ret bfloat %res
}

declare bfloat @llvm.vector.reduce.fadd.nxv2bf16(bfloat, <vscale x 2 x bfloat>)
declare bfloat @llvm.vector.reduce.fadd.nxv4bf16(bfloat, <vscale x 4 x bfloat>)
declare bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat, <vscale x 8 x bfloat>)

declare bfloat @llvm.vector.reduce.fmax.nxv2bf16(<vscale x 2 x bfloat>)
declare bfloat @llvm.vector.reduce.fmax.nxv4bf16(<vscale x 4 x bfloat>)
declare bfloat @llvm.vector.reduce.fmax.nxv8bf16(<vscale x 8 x bfloat>)

declare bfloat @llvm.vector.reduce.fmin.nxv2bf16(<vscale x 2 x bfloat>)
declare bfloat @llvm.vector.reduce.fmin.nxv4bf16(<vscale x 4 x bfloat>)
declare bfloat @llvm.vector.reduce.fmin.nxv8bf16(<vscale x 8 x bfloat>)

declare bfloat @llvm.vector.reduce.fmaximum.nxv2bf16(<vscale x 2 x bfloat>)
declare bfloat @llvm.vector.reduce.fmaximum.nxv4bf16(<vscale x 4 x bfloat>)
declare bfloat @llvm.vector.reduce.fmaximum.nxv8bf16(<vscale x 8 x bfloat>)

declare bfloat @llvm.vector.reduce.fminimum.nxv2bf16(<vscale x 2 x bfloat>)
declare bfloat @llvm.vector.reduce.fminimum.nxv4bf16(<vscale x 4 x bfloat>)
declare bfloat @llvm.vector.reduce.fminimum.nxv8bf16(<vscale x 8 x bfloat>)
Loading