-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[LLVM][CodeGen][SVE] Add isel for bfloat unordered reductions. #143540
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -188,6 +188,7 @@ class VectorLegalizer { | |
| void PromoteSETCC(SDNode *Node, SmallVectorImpl<SDValue> &Results); | ||
|
|
||
| void PromoteSTRICT(SDNode *Node, SmallVectorImpl<SDValue> &Results); | ||
| void PromoteVECREDUCE(SDNode *Node, SmallVectorImpl<SDValue> &Results); | ||
|
|
||
| public: | ||
| VectorLegalizer(SelectionDAG& dag) : | ||
|
|
@@ -500,20 +501,14 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { | |
| case ISD::VECREDUCE_UMAX: | ||
| case ISD::VECREDUCE_UMIN: | ||
| case ISD::VECREDUCE_FADD: | ||
| case ISD::VECREDUCE_FMUL: | ||
| case ISD::VECTOR_FIND_LAST_ACTIVE: | ||
| Action = TLI.getOperationAction(Node->getOpcode(), | ||
| Node->getOperand(0).getValueType()); | ||
| break; | ||
| case ISD::VECREDUCE_FMAX: | ||
| case ISD::VECREDUCE_FMIN: | ||
| case ISD::VECREDUCE_FMAXIMUM: | ||
| case ISD::VECREDUCE_FMIN: | ||
| case ISD::VECREDUCE_FMINIMUM: | ||
| case ISD::VECREDUCE_FMUL: | ||
| case ISD::VECTOR_FIND_LAST_ACTIVE: | ||
| Action = TLI.getOperationAction(Node->getOpcode(), | ||
| Node->getOperand(0).getValueType()); | ||
| // Defer non-vector results to LegalizeDAG. | ||
| if (Action == TargetLowering::Promote) | ||
| Action = TargetLowering::Legal; | ||
| break; | ||
| case ISD::VECREDUCE_SEQ_FADD: | ||
| case ISD::VECREDUCE_SEQ_FMUL: | ||
|
|
@@ -688,6 +683,22 @@ void VectorLegalizer::PromoteSTRICT(SDNode *Node, | |
| Results.push_back(Round.getValue(1)); | ||
| } | ||
|
|
||
| void VectorLegalizer::PromoteVECREDUCE(SDNode *Node, | ||
| SmallVectorImpl<SDValue> &Results) { | ||
| MVT OpVT = Node->getOperand(0).getSimpleValueType(); | ||
| assert(OpVT.isFloatingPoint() && "Expected floating point reduction!"); | ||
| MVT NewOpVT = TLI.getTypeToPromoteTo(Node->getOpcode(), OpVT); | ||
|
|
||
| SDLoc DL(Node); | ||
| SDValue NewOp = DAG.getNode(ISD::FP_EXTEND, DL, NewOpVT, Node->getOperand(0)); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What about the strictfp case?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does strictfp apply here? There are no STRICT_VECREDUCE_ nodes. There are the ordered VECREDUCE_SEQ_ nodes, but they go down a different path, so they are not covered by this PR.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes. This pushes on the broader issue with the current strictfp strategy, where we need to duplicate all possible FP intrinsics, and we're missing them here. So for these non-strict intrinsics it's not an issue; the issue is that we don't have strict versions of these. |
||
| SDValue Rdx = | ||
| DAG.getNode(Node->getOpcode(), DL, NewOpVT.getVectorElementType(), NewOp, | ||
| Node->getFlags()); | ||
| SDValue Res = DAG.getNode(ISD::FP_ROUND, DL, Node->getValueType(0), Rdx, | ||
| DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)); | ||
|
||
| Results.push_back(Res); | ||
| } | ||
|
|
||
| void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) { | ||
| // For a few operations there is a specific concept for promotion based on | ||
| // the operand's type. | ||
|
|
@@ -719,6 +730,13 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) { | |
| case ISD::STRICT_FMA: | ||
| PromoteSTRICT(Node, Results); | ||
| return; | ||
| case ISD::VECREDUCE_FADD: | ||
| case ISD::VECREDUCE_FMAX: | ||
| case ISD::VECREDUCE_FMAXIMUM: | ||
| case ISD::VECREDUCE_FMIN: | ||
| case ISD::VECREDUCE_FMINIMUM: | ||
| PromoteVECREDUCE(Node, Results); | ||
| return; | ||
| case ISD::FP_ROUND: | ||
| case ISD::FP_EXTEND: | ||
| // These operations are used to do promotion so they can't be promoted | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -11412,13 +11412,9 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const { | |||||||||||||||
| SDValue Op = Node->getOperand(0); | ||||||||||||||||
| EVT VT = Op.getValueType(); | ||||||||||||||||
|
|
||||||||||||||||
| if (VT.isScalableVector()) | ||||||||||||||||
| report_fatal_error( | ||||||||||||||||
| "Expanding reductions for scalable vectors is undefined."); | ||||||||||||||||
|
|
||||||||||||||||
| // Try to use a shuffle reduction for power of two vectors. | ||||||||||||||||
| if (VT.isPow2VectorType()) { | ||||||||||||||||
| while (VT.getVectorNumElements() > 1) { | ||||||||||||||||
| while (VT.getVectorElementCount().isKnownMultipleOf(2)) { | ||||||||||||||||
| EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext()); | ||||||||||||||||
| if (!isOperationLegalOrCustom(BaseOpcode, HalfVT)) | ||||||||||||||||
| break; | ||||||||||||||||
|
|
@@ -11427,9 +11423,18 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const { | |||||||||||||||
| std::tie(Lo, Hi) = DAG.SplitVector(Op, dl); | ||||||||||||||||
| Op = DAG.getNode(BaseOpcode, dl, HalfVT, Lo, Hi, Node->getFlags()); | ||||||||||||||||
| VT = HalfVT; | ||||||||||||||||
|
|
||||||||||||||||
| // Stop if splitting is enough to make the reduction legal. | ||||||||||||||||
| if (isOperationLegalOrCustom(Node->getOpcode(), HalfVT)) | ||||||||||||||||
| return DAG.getNode(Node->getOpcode(), dl, Node->getValueType(0), Op, | ||||||||||||||||
| Node->getFlags()); | ||||||||||||||||
| } | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| if (VT.isScalableVector()) | ||||||||||||||||
| report_fatal_error( | ||||||||||||||||
| "Expanding reductions for scalable vectors is undefined."); | ||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||
|
|
||||||||||||||||
| EVT EltVT = VT.getVectorElementType(); | ||||||||||||||||
| unsigned NumElts = VT.getVectorNumElements(); | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,235 @@ | ||
| ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 | ||
| ; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s | ||
| ; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s | ||
|
|
||
| target triple = "aarch64-unknown-linux-gnu" | ||
|
|
||
| ; FADDV | ||
|
|
||
| define bfloat @faddv_nxv2bf16(<vscale x 2 x bfloat> %a) { | ||
| ; CHECK-LABEL: faddv_nxv2bf16: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: lsl z0.s, z0.s, #16 | ||
| ; CHECK-NEXT: ptrue p0.d | ||
| ; CHECK-NEXT: faddv s0, p0, z0.s | ||
| ; CHECK-NEXT: bfcvt h0, s0 | ||
| ; CHECK-NEXT: ret | ||
| %res = call fast bfloat @llvm.vector.reduce.fadd.nxv2bf16(bfloat zeroinitializer, <vscale x 2 x bfloat> %a) | ||
| ret bfloat %res | ||
| } | ||
|
|
||
| define bfloat @faddv_nxv4bf16(<vscale x 4 x bfloat> %a) { | ||
| ; CHECK-LABEL: faddv_nxv4bf16: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: lsl z0.s, z0.s, #16 | ||
| ; CHECK-NEXT: ptrue p0.s | ||
| ; CHECK-NEXT: faddv s0, p0, z0.s | ||
| ; CHECK-NEXT: bfcvt h0, s0 | ||
| ; CHECK-NEXT: ret | ||
| %res = call fast bfloat @llvm.vector.reduce.fadd.nxv4bf16(bfloat zeroinitializer, <vscale x 4 x bfloat> %a) | ||
| ret bfloat %res | ||
| } | ||
|
|
||
| define bfloat @faddv_nxv8bf16(<vscale x 8 x bfloat> %a) { | ||
| ; CHECK-LABEL: faddv_nxv8bf16: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: uunpkhi z1.s, z0.h | ||
| ; CHECK-NEXT: uunpklo z0.s, z0.h | ||
| ; CHECK-NEXT: ptrue p0.s | ||
| ; CHECK-NEXT: lsl z1.s, z1.s, #16 | ||
| ; CHECK-NEXT: lsl z0.s, z0.s, #16 | ||
| ; CHECK-NEXT: fadd z0.s, z0.s, z1.s | ||
| ; CHECK-NEXT: faddv s0, p0, z0.s | ||
| ; CHECK-NEXT: bfcvt h0, s0 | ||
| ; CHECK-NEXT: ret | ||
| %res = call fast bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat zeroinitializer, <vscale x 8 x bfloat> %a) | ||
| ret bfloat %res | ||
| } | ||
|
|
||
| ; FMAXNMV | ||
|
|
||
| define bfloat @fmaxv_nxv2bf16(<vscale x 2 x bfloat> %a) { | ||
| ; CHECK-LABEL: fmaxv_nxv2bf16: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: lsl z0.s, z0.s, #16 | ||
| ; CHECK-NEXT: ptrue p0.d | ||
| ; CHECK-NEXT: fmaxnmv s0, p0, z0.s | ||
| ; CHECK-NEXT: bfcvt h0, s0 | ||
| ; CHECK-NEXT: ret | ||
| %res = call bfloat @llvm.vector.reduce.fmax.nxv2bf16(<vscale x 2 x bfloat> %a) | ||
| ret bfloat %res | ||
| } | ||
|
|
||
| define bfloat @fmaxv_nxv4bf16(<vscale x 4 x bfloat> %a) { | ||
| ; CHECK-LABEL: fmaxv_nxv4bf16: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: lsl z0.s, z0.s, #16 | ||
| ; CHECK-NEXT: ptrue p0.s | ||
| ; CHECK-NEXT: fmaxnmv s0, p0, z0.s | ||
| ; CHECK-NEXT: bfcvt h0, s0 | ||
| ; CHECK-NEXT: ret | ||
| %res = call bfloat @llvm.vector.reduce.fmax.nxv4bf16(<vscale x 4 x bfloat> %a) | ||
| ret bfloat %res | ||
| } | ||
|
|
||
| define bfloat @fmaxv_nxv8bf16(<vscale x 8 x bfloat> %a) { | ||
| ; CHECK-LABEL: fmaxv_nxv8bf16: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: uunpkhi z1.s, z0.h | ||
| ; CHECK-NEXT: uunpklo z0.s, z0.h | ||
| ; CHECK-NEXT: ptrue p0.s | ||
| ; CHECK-NEXT: lsl z1.s, z1.s, #16 | ||
| ; CHECK-NEXT: lsl z0.s, z0.s, #16 | ||
| ; CHECK-NEXT: fmaxnm z0.s, p0/m, z0.s, z1.s | ||
| ; CHECK-NEXT: fmaxnmv s0, p0, z0.s | ||
| ; CHECK-NEXT: bfcvt h0, s0 | ||
| ; CHECK-NEXT: ret | ||
| %res = call bfloat @llvm.vector.reduce.fmax.nxv8bf16(<vscale x 8 x bfloat> %a) | ||
| ret bfloat %res | ||
| } | ||
|
|
||
| ; FMINNMV | ||
|
|
||
| define bfloat @fminv_nxv2bf16(<vscale x 2 x bfloat> %a) { | ||
| ; CHECK-LABEL: fminv_nxv2bf16: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: lsl z0.s, z0.s, #16 | ||
| ; CHECK-NEXT: ptrue p0.d | ||
| ; CHECK-NEXT: fminnmv s0, p0, z0.s | ||
| ; CHECK-NEXT: bfcvt h0, s0 | ||
| ; CHECK-NEXT: ret | ||
| %res = call bfloat @llvm.vector.reduce.fmin.nxv2bf16(<vscale x 2 x bfloat> %a) | ||
| ret bfloat %res | ||
| } | ||
|
|
||
| define bfloat @fminv_nxv4bf16(<vscale x 4 x bfloat> %a) { | ||
| ; CHECK-LABEL: fminv_nxv4bf16: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: lsl z0.s, z0.s, #16 | ||
| ; CHECK-NEXT: ptrue p0.s | ||
| ; CHECK-NEXT: fminnmv s0, p0, z0.s | ||
| ; CHECK-NEXT: bfcvt h0, s0 | ||
| ; CHECK-NEXT: ret | ||
| %res = call bfloat @llvm.vector.reduce.fmin.nxv4bf16(<vscale x 4 x bfloat> %a) | ||
| ret bfloat %res | ||
| } | ||
|
|
||
| define bfloat @fminv_nxv8bf16(<vscale x 8 x bfloat> %a) { | ||
| ; CHECK-LABEL: fminv_nxv8bf16: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: uunpkhi z1.s, z0.h | ||
| ; CHECK-NEXT: uunpklo z0.s, z0.h | ||
| ; CHECK-NEXT: ptrue p0.s | ||
| ; CHECK-NEXT: lsl z1.s, z1.s, #16 | ||
| ; CHECK-NEXT: lsl z0.s, z0.s, #16 | ||
| ; CHECK-NEXT: fminnm z0.s, p0/m, z0.s, z1.s | ||
| ; CHECK-NEXT: fminnmv s0, p0, z0.s | ||
| ; CHECK-NEXT: bfcvt h0, s0 | ||
| ; CHECK-NEXT: ret | ||
| %res = call bfloat @llvm.vector.reduce.fmin.nxv8bf16(<vscale x 8 x bfloat> %a) | ||
| ret bfloat %res | ||
| } | ||
|
|
||
| ; FMAXV | ||
|
|
||
| define bfloat @fmaximumv_nxv2bf16(<vscale x 2 x bfloat> %a) { | ||
| ; CHECK-LABEL: fmaximumv_nxv2bf16: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: lsl z0.s, z0.s, #16 | ||
| ; CHECK-NEXT: ptrue p0.d | ||
| ; CHECK-NEXT: fmaxv s0, p0, z0.s | ||
| ; CHECK-NEXT: bfcvt h0, s0 | ||
| ; CHECK-NEXT: ret | ||
| %res = call bfloat @llvm.vector.reduce.fmaximum.nxv2bf16(<vscale x 2 x bfloat> %a) | ||
| ret bfloat %res | ||
| } | ||
|
|
||
| define bfloat @fmaximumv_nxv4bf16(<vscale x 4 x bfloat> %a) { | ||
| ; CHECK-LABEL: fmaximumv_nxv4bf16: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: lsl z0.s, z0.s, #16 | ||
| ; CHECK-NEXT: ptrue p0.s | ||
| ; CHECK-NEXT: fmaxv s0, p0, z0.s | ||
| ; CHECK-NEXT: bfcvt h0, s0 | ||
| ; CHECK-NEXT: ret | ||
| %res = call bfloat @llvm.vector.reduce.fmaximum.nxv4bf16(<vscale x 4 x bfloat> %a) | ||
| ret bfloat %res | ||
| } | ||
|
|
||
| define bfloat @fmaximumv_nxv8bf16(<vscale x 8 x bfloat> %a) { | ||
| ; CHECK-LABEL: fmaximumv_nxv8bf16: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: uunpkhi z1.s, z0.h | ||
| ; CHECK-NEXT: uunpklo z0.s, z0.h | ||
| ; CHECK-NEXT: ptrue p0.s | ||
| ; CHECK-NEXT: lsl z1.s, z1.s, #16 | ||
| ; CHECK-NEXT: lsl z0.s, z0.s, #16 | ||
| ; CHECK-NEXT: fmax z0.s, p0/m, z0.s, z1.s | ||
| ; CHECK-NEXT: fmaxv s0, p0, z0.s | ||
| ; CHECK-NEXT: bfcvt h0, s0 | ||
| ; CHECK-NEXT: ret | ||
| %res = call bfloat @llvm.vector.reduce.fmaximum.nxv8bf16(<vscale x 8 x bfloat> %a) | ||
| ret bfloat %res | ||
| } | ||
|
|
||
| ; FMINV | ||
|
|
||
| define bfloat @fminimumv_nxv2bf16(<vscale x 2 x bfloat> %a) { | ||
| ; CHECK-LABEL: fminimumv_nxv2bf16: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: lsl z0.s, z0.s, #16 | ||
| ; CHECK-NEXT: ptrue p0.d | ||
| ; CHECK-NEXT: fminv s0, p0, z0.s | ||
| ; CHECK-NEXT: bfcvt h0, s0 | ||
| ; CHECK-NEXT: ret | ||
| %res = call bfloat @llvm.vector.reduce.fminimum.nxv2bf16(<vscale x 2 x bfloat> %a) | ||
| ret bfloat %res | ||
| } | ||
|
|
||
| define bfloat @fminimumv_nxv4bf16(<vscale x 4 x bfloat> %a) { | ||
| ; CHECK-LABEL: fminimumv_nxv4bf16: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: lsl z0.s, z0.s, #16 | ||
| ; CHECK-NEXT: ptrue p0.s | ||
| ; CHECK-NEXT: fminv s0, p0, z0.s | ||
| ; CHECK-NEXT: bfcvt h0, s0 | ||
| ; CHECK-NEXT: ret | ||
| %res = call bfloat @llvm.vector.reduce.fminimum.nxv4bf16(<vscale x 4 x bfloat> %a) | ||
| ret bfloat %res | ||
| } | ||
|
|
||
| define bfloat @fminimumv_nxv8bf16(<vscale x 8 x bfloat> %a) { | ||
| ; CHECK-LABEL: fminimumv_nxv8bf16: | ||
| ; CHECK: // %bb.0: | ||
| ; CHECK-NEXT: uunpkhi z1.s, z0.h | ||
| ; CHECK-NEXT: uunpklo z0.s, z0.h | ||
| ; CHECK-NEXT: ptrue p0.s | ||
| ; CHECK-NEXT: lsl z1.s, z1.s, #16 | ||
| ; CHECK-NEXT: lsl z0.s, z0.s, #16 | ||
| ; CHECK-NEXT: fmin z0.s, p0/m, z0.s, z1.s | ||
| ; CHECK-NEXT: fminv s0, p0, z0.s | ||
| ; CHECK-NEXT: bfcvt h0, s0 | ||
| ; CHECK-NEXT: ret | ||
| %res = call bfloat @llvm.vector.reduce.fminimum.nxv8bf16(<vscale x 8 x bfloat> %a) | ||
| ret bfloat %res | ||
| } | ||
|
|
||
| declare bfloat @llvm.vector.reduce.fadd.nxv2bf16(bfloat, <vscale x 2 x bfloat>) | ||
| declare bfloat @llvm.vector.reduce.fadd.nxv4bf16(bfloat, <vscale x 4 x bfloat>) | ||
| declare bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat, <vscale x 8 x bfloat>) | ||
|
|
||
| declare bfloat @llvm.vector.reduce.fmax.nxv2bf16(<vscale x 2 x bfloat>) | ||
| declare bfloat @llvm.vector.reduce.fmax.nxv4bf16(<vscale x 4 x bfloat>) | ||
| declare bfloat @llvm.vector.reduce.fmax.nxv8bf16(<vscale x 8 x bfloat>) | ||
|
|
||
| declare bfloat @llvm.vector.reduce.fmin.nxv2bf16(<vscale x 2 x bfloat>) | ||
| declare bfloat @llvm.vector.reduce.fmin.nxv4bf16(<vscale x 4 x bfloat>) | ||
| declare bfloat @llvm.vector.reduce.fmin.nxv8bf16(<vscale x 8 x bfloat>) | ||
|
|
||
| declare bfloat @llvm.vector.reduce.fmaximum.nxv2bf16(<vscale x 2 x bfloat>) | ||
| declare bfloat @llvm.vector.reduce.fmaximum.nxv4bf16(<vscale x 4 x bfloat>) | ||
| declare bfloat @llvm.vector.reduce.fmaximum.nxv8bf16(<vscale x 8 x bfloat>) | ||
|
|
||
| declare bfloat @llvm.vector.reduce.fminimum.nxv2bf16(<vscale x 2 x bfloat>) | ||
| declare bfloat @llvm.vector.reduce.fminimum.nxv4bf16(<vscale x 4 x bfloat>) | ||
| declare bfloat @llvm.vector.reduce.fminimum.nxv8bf16(<vscale x 8 x bfloat>) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't clearly indicate that this is for FP operations.