diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 4a1cd642233ef..f908a66128ec8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -189,6 +189,12 @@ class VectorLegalizer {
 
   void PromoteSTRICT(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
+  /// Calculate the reduction using a type of higher precision and round the
+  /// result to match the original type. Setting NonArithmetic signifies the
+  /// rounding of the result does not affect its value.
+  void PromoteFloatVECREDUCE(SDNode *Node, SmallVectorImpl<SDValue> &Results,
+                             bool NonArithmetic);
+
 public:
   VectorLegalizer(SelectionDAG& dag) :
       DAG(dag), TLI(dag.getTargetLoweringInfo()) {}
@@ -500,20 +506,14 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::VECREDUCE_UMAX:
   case ISD::VECREDUCE_UMIN:
   case ISD::VECREDUCE_FADD:
-  case ISD::VECREDUCE_FMUL:
-  case ISD::VECTOR_FIND_LAST_ACTIVE:
-    Action = TLI.getOperationAction(Node->getOpcode(),
-                                    Node->getOperand(0).getValueType());
-    break;
   case ISD::VECREDUCE_FMAX:
-  case ISD::VECREDUCE_FMIN:
   case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMIN:
   case ISD::VECREDUCE_FMINIMUM:
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECTOR_FIND_LAST_ACTIVE:
     Action = TLI.getOperationAction(Node->getOpcode(),
                                     Node->getOperand(0).getValueType());
-    // Defer non-vector results to LegalizeDAG.
-    if (Action == TargetLowering::Promote)
-      Action = TargetLowering::Legal;
     break;
   case ISD::VECREDUCE_SEQ_FADD:
   case ISD::VECREDUCE_SEQ_FMUL:
@@ -688,6 +688,24 @@ void VectorLegalizer::PromoteSTRICT(SDNode *Node,
   Results.push_back(Round.getValue(1));
 }
 
+void VectorLegalizer::PromoteFloatVECREDUCE(SDNode *Node,
+                                            SmallVectorImpl<SDValue> &Results,
+                                            bool NonArithmetic) {
+  MVT OpVT = Node->getOperand(0).getSimpleValueType();
+  assert(OpVT.isFloatingPoint() && "Expected floating point reduction!");
+  MVT NewOpVT = TLI.getTypeToPromoteTo(Node->getOpcode(), OpVT);
+
+  SDLoc DL(Node);
+  SDValue NewOp = DAG.getNode(ISD::FP_EXTEND, DL, NewOpVT, Node->getOperand(0));
+  SDValue Rdx =
+      DAG.getNode(Node->getOpcode(), DL, NewOpVT.getVectorElementType(), NewOp,
+                  Node->getFlags());
+  SDValue Res =
+      DAG.getNode(ISD::FP_ROUND, DL, Node->getValueType(0), Rdx,
+                  DAG.getIntPtrConstant(NonArithmetic, DL, /*isTarget=*/true));
+  Results.push_back(Res);
+}
+
 void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   // For a few operations there is a specific concept for promotion based on
   // the operand's type.
@@ -719,6 +737,15 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   case ISD::STRICT_FMA:
     PromoteSTRICT(Node, Results);
     return;
+  case ISD::VECREDUCE_FADD:
+    PromoteFloatVECREDUCE(Node, Results, /*NonArithmetic=*/false);
+    return;
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMINIMUM:
+    PromoteFloatVECREDUCE(Node, Results, /*NonArithmetic=*/true);
+    return;
   case ISD::FP_ROUND:
   case ISD::FP_EXTEND:
     // These operations are used to do promotion so they can't be promoted
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index a0ffb4b6d5a4c..ca1a1165115cf 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -11412,13 +11412,9 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
   SDValue Op = Node->getOperand(0);
   EVT VT = Op.getValueType();
 
-  if (VT.isScalableVector())
-    report_fatal_error(
-        "Expanding reductions for scalable vectors is undefined.");
-
   // Try to use a shuffle reduction for power of two vectors.
   if (VT.isPow2VectorType()) {
-    while (VT.getVectorNumElements() > 1) {
+    while (VT.getVectorElementCount().isKnownMultipleOf(2)) {
       EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
       if (!isOperationLegalOrCustom(BaseOpcode, HalfVT))
         break;
@@ -11427,9 +11423,18 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
       std::tie(Lo, Hi) = DAG.SplitVector(Op, dl);
       Op = DAG.getNode(BaseOpcode, dl, HalfVT, Lo, Hi, Node->getFlags());
       VT = HalfVT;
+
+      // Stop if splitting is enough to make the reduction legal.
+      if (isOperationLegalOrCustom(Node->getOpcode(), HalfVT))
+        return DAG.getNode(Node->getOpcode(), dl, Node->getValueType(0), Op,
+                           Node->getFlags());
     }
   }
 
+  if (VT.isScalableVector())
+    reportFatalInternalError(
+        "Expanding reductions for scalable vectors is undefined.");
+
   EVT EltVT = VT.getVectorElementType();
   unsigned NumElts = VT.getVectorNumElements();
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index caac00c5b2faa..9322f615827d9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1780,7 +1780,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     for (auto Opcode :
          {ISD::FCEIL, ISD::FDIV, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
-          ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC, ISD::SETCC}) {
+          ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC, ISD::SETCC,
+          ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMAXIMUM,
+          ISD::VECREDUCE_FMIN, ISD::VECREDUCE_FMINIMUM}) {
       setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
       setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
       setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
new file mode 100644
index 0000000000000..7f79c9c5431ea
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
@@ -0,0 +1,279 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s -check-prefixes=CHECK,SVE
+; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s -check-prefixes=CHECK,SME
+
+target triple = "aarch64-unknown-linux-gnu"
+
+; FADDV
+
+define bfloat @faddv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: faddv_nxv2bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    faddv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call fast bfloat @llvm.vector.reduce.fadd.nxv2bf16(bfloat zeroinitializer, <vscale x 2 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @faddv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: faddv_nxv4bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    faddv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call fast bfloat @llvm.vector.reduce.fadd.nxv4bf16(bfloat zeroinitializer, <vscale x 4 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @faddv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: faddv_nxv8bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpkhi z1.s, z0.h
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    fadd z0.s, z0.s, z1.s
+; CHECK-NEXT:    faddv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call fast bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat zeroinitializer, <vscale x 8 x bfloat> %a)
+  ret bfloat %res
+}
+
+; FMAXNMV
+
+define bfloat @fmaxv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fmaxv_nxv2bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fmaxnmv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fmax.nxv2bf16(<vscale x 2 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @fmaxv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fmaxv_nxv4bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fmaxnmv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fmax.nxv4bf16(<vscale x 4 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @fmaxv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fmaxv_nxv8bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpkhi z1.s, z0.h
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    fmaxnm z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    fmaxnmv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fmax.nxv8bf16(<vscale x 8 x bfloat> %a)
+  ret bfloat %res
+}
+
+; FMINNMV
+
+define bfloat @fminv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fminv_nxv2bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fminnmv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fmin.nxv2bf16(<vscale x 2 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @fminv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fminv_nxv4bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fminnmv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fmin.nxv4bf16(<vscale x 4 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @fminv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fminv_nxv8bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpkhi z1.s, z0.h
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    fminnm z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    fminnmv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fmin.nxv8bf16(<vscale x 8 x bfloat> %a)
+  ret bfloat %res
+}
+
+; FMAXV
+
+define bfloat @fmaximumv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fmaximumv_nxv2bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fmaxv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fmaximum.nxv2bf16(<vscale x 2 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @fmaximumv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fmaximumv_nxv4bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fmaxv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fmaximum.nxv4bf16(<vscale x 4 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @fmaximumv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fmaximumv_nxv8bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpkhi z1.s, z0.h
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    fmax z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    fmaxv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fmaximum.nxv8bf16(<vscale x 8 x bfloat> %a)
+  ret bfloat %res
+}
+
+; FMINV
+
+define bfloat @fminimumv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fminimumv_nxv2bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fminv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fminimum.nxv2bf16(<vscale x 2 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @fminimumv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fminimumv_nxv4bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fminv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fminimum.nxv4bf16(<vscale x 4 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @fminimumv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fminimumv_nxv8bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpkhi z1.s, z0.h
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    fmin z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    fminv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fminimum.nxv8bf16(<vscale x 8 x bfloat> %a)
+  ret bfloat %res
+}
+
+; The reduction is performed at a higher precision. Because add operations
+; can utilise that precision, its result must be rounded even if it's then
+; promoted.
+define float @promoted_fadd(<vscale x 4 x bfloat> %a) {
+; SVE-LABEL: promoted_fadd:
+; SVE:       // %bb.0:
+; SVE-NEXT:    lsl z0.s, z0.s, #16
+; SVE-NEXT:    ptrue p0.s
+; SVE-NEXT:    faddv s0, p0, z0.s
+; SVE-NEXT:    bfcvt h0, s0
+; SVE-NEXT:    shll v0.4s, v0.4h, #16
+; SVE-NEXT:    // kill: def $s0 killed $s0 killed $q0
+; SVE-NEXT:    ret
+;
+; SME-LABEL: promoted_fadd:
+; SME:       // %bb.0:
+; SME-NEXT:    lsl z0.s, z0.s, #16
+; SME-NEXT:    ptrue p0.s
+; SME-NEXT:    faddv s0, p0, z0.s
+; SME-NEXT:    bfcvt h0, s0
+; SME-NEXT:    fmov w8, s0
+; SME-NEXT:    lsl w8, w8, #16
+; SME-NEXT:    fmov s0, w8
+; SME-NEXT:    ret
+  %rdx = call fast bfloat @llvm.vector.reduce.fadd.nxv4bf16(bfloat zeroinitializer, <vscale x 4 x bfloat> %a)
+  %res = fpext bfloat %rdx to float
+  ret float %res
+}
+
+; The reduction is performed at a higher precision. Because min/max operations
+; don't utilise that precision, its result can be used directly.
+define float @promoted_fmax(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: promoted_fmax:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fmaxnmv s0, p0, z0.s
+; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $z0
+; CHECK-NEXT:    ret
+  %rdx = call bfloat @llvm.vector.reduce.fmax.nxv4bf16(<vscale x 4 x bfloat> %a)
+  %res = fpext bfloat %rdx to float
+  ret float %res
+}
+
+declare bfloat @llvm.vector.reduce.fadd.nxv2bf16(bfloat, <vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fadd.nxv4bf16(bfloat, <vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat, <vscale x 8 x bfloat>)
+
+declare bfloat @llvm.vector.reduce.fmax.nxv2bf16(<vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmax.nxv4bf16(<vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmax.nxv8bf16(<vscale x 8 x bfloat>)
+
+declare bfloat @llvm.vector.reduce.fmin.nxv2bf16(<vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmin.nxv4bf16(<vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmin.nxv8bf16(<vscale x 8 x bfloat>)
+
+declare bfloat @llvm.vector.reduce.fmaximum.nxv2bf16(<vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmaximum.nxv4bf16(<vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmaximum.nxv8bf16(<vscale x 8 x bfloat>)
+
+declare bfloat @llvm.vector.reduce.fminimum.nxv2bf16(<vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fminimum.nxv4bf16(<vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fminimum.nxv8bf16(<vscale x 8 x bfloat>)
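
Note (illustrative sketch, not part of the patch): at the IR level, the promotion PromoteFloatVECREDUCE performs is equivalent to widening the bf16 operand, reducing at f32, and truncating the scalar result back. The function and value names below are hypothetical; only the intrinsics and types are real. For FADD the final truncation is genuine arithmetic rounding (NonArithmetic=false), whereas for FMAX/FMIN/FMAXIMUM/FMINIMUM the f32 result is already exactly representable as bf16, so the FP_ROUND is marked value-preserving (NonArithmetic=true) and a following fpext can fold it away, as the promoted_fmax test checks.

; Sketch of the promotion for a <vscale x 4 x bfloat> fadd reduction.
define bfloat @sketch_promoted_faddv(<vscale x 4 x bfloat> %a) {
  ; FP_EXTEND: widen the operand to the promoted type nxv4f32.
  %wide = fpext <vscale x 4 x bfloat> %a to <vscale x 4 x float>
  ; Reduce at the higher precision ('fast' permits the unordered reduction,
  ; matching the fast flag used by the FADDV tests above).
  %rdx = call fast float @llvm.vector.reduce.fadd.nxv4f32(float 0.0, <vscale x 4 x float> %wide)
  ; FP_ROUND: truncate back to bf16. Arithmetic rounding for fadd; for the
  ; min/max reductions this step does not change the value.
  %res = fptrunc float %rdx to bfloat
  ret bfloat %res
}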