-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[SelectionDAG][x86] Ensure vector reduction optimization #144231
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
2b2130a
93df21f
57a3788
ebb3ba0
85e5195
90da8ff
376965e
886f8a7
1bf5442
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1435,6 +1435,20 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, | |
| setOperationAction(ISD::BITREVERSE, VT, Custom); | ||
| } | ||
|
|
||
| // Vector min/max reductions | ||
| if (Subtarget.hasSSE41()) | ||
| { | ||
| for (MVT VT : MVT::vector_valuetypes()) { | ||
| if (VT.getScalarType() == MVT::i8 || VT.getScalarType() == MVT::i16) | ||
| { | ||
| setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom); | ||
| setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom); | ||
| setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom); | ||
| setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| if (!Subtarget.useSoftFloat() && Subtarget.hasAVX()) { | ||
| bool HasInt256 = Subtarget.hasInt256(); | ||
|
|
||
|
|
@@ -25409,6 +25423,94 @@ static SDValue LowerEXTEND_VECTOR_INREG(SDValue Op, | |
| return SignExt; | ||
| } | ||
|
|
||
| // Create a min/max v8i16/v16i8 horizontal reduction with PHMINPOSUW. | ||
| static SDValue createMinMaxReduction(SDValue Src, EVT TargetVT, SDLoc DL, | ||
| ISD::NodeType BinOp, SelectionDAG &DAG, | ||
| const X86Subtarget &Subtarget) | ||
| { | ||
| assert(Subtarget.hasSSE41() && "The caller must check if SSE4.1 is available"); | ||
|
|
||
| EVT SrcVT = Src.getValueType(); | ||
| EVT SrcSVT = SrcVT.getScalarType(); | ||
|
|
||
| if (SrcSVT != TargetVT || (SrcVT.getSizeInBits() % 128) != 0) | ||
| return SDValue(); | ||
|
|
||
| // First, reduce the source down to 128-bit, applying BinOp to lo/hi. | ||
| while (SrcVT.getSizeInBits() > 128) { | ||
| SDValue Lo, Hi; | ||
| std::tie(Lo, Hi) = splitVector(Src, DAG, DL); | ||
| SrcVT = Lo.getValueType(); | ||
| Src = DAG.getNode(BinOp, DL, SrcVT, Lo, Hi); | ||
| } | ||
| assert(((SrcVT == MVT::v8i16 && TargetVT == MVT::i16) || | ||
| (SrcVT == MVT::v16i8 && TargetVT == MVT::i8)) && | ||
| "Unexpected value type"); | ||
|
|
||
| // PHMINPOSUW applies to UMIN(v8i16), for SMIN/SMAX/UMAX we must apply a mask | ||
| // to flip the value accordingly. | ||
| SDValue Mask; | ||
| unsigned MaskEltsBits = TargetVT.getSizeInBits(); | ||
| if (BinOp == ISD::SMAX) | ||
| Mask = DAG.getConstant(APInt::getSignedMaxValue(MaskEltsBits), DL, SrcVT); | ||
| else if (BinOp == ISD::SMIN) | ||
| Mask = DAG.getConstant(APInt::getSignedMinValue(MaskEltsBits), DL, SrcVT); | ||
| else if (BinOp == ISD::UMAX) | ||
| Mask = DAG.getAllOnesConstant(DL, SrcVT); | ||
|
|
||
| if (Mask) | ||
| Src = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, Src); | ||
|
|
||
| // For v16i8 cases we need to perform UMIN on pairs of byte elements, | ||
| // shuffling each upper element down and insert zeros. This means that the | ||
| // v16i8 UMIN will leave the upper element as zero, performing zero-extension | ||
| // ready for the PHMINPOS. | ||
| if (TargetVT == MVT::i8) { | ||
| SDValue Upper = DAG.getVectorShuffle( | ||
| SrcVT, DL, Src, DAG.getConstant(0, DL, MVT::v16i8), | ||
| {1, 16, 3, 16, 5, 16, 7, 16, 9, 16, 11, 16, 13, 16, 15, 16}); | ||
| Src = DAG.getNode(ISD::UMIN, DL, SrcVT, Src, Upper); | ||
| } | ||
|
|
||
| // Perform the PHMINPOS on a v8i16 vector, | ||
| Src = DAG.getBitcast(MVT::v8i16, Src); | ||
| Src = DAG.getNode(X86ISD::PHMINPOS, DL, MVT::v8i16, Src); | ||
| Src = DAG.getBitcast(SrcVT, Src); | ||
|
|
||
| if (Mask) | ||
| Src = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, Src); | ||
|
|
||
| return DAG.getExtractVectorElt(DL, TargetVT, Src, 0); | ||
| } | ||
|
|
||
| static SDValue LowerVECTOR_REDUCE_MINMAX(SDValue Op, | ||
| const X86Subtarget& Subtarget, | ||
| SelectionDAG& DAG) | ||
| { | ||
| ISD::NodeType BinOp; | ||
| switch (Op.getOpcode()) | ||
| { | ||
| default: | ||
| assert(false && "Expected min/max reduction"); | ||
| break; | ||
| case ISD::VECREDUCE_UMIN: | ||
| BinOp = ISD::UMIN; | ||
| break; | ||
| case ISD::VECREDUCE_UMAX: | ||
| BinOp = ISD::UMAX; | ||
| break; | ||
| case ISD::VECREDUCE_SMIN: | ||
| BinOp = ISD::SMIN; | ||
| break; | ||
| case ISD::VECREDUCE_SMAX: | ||
| BinOp = ISD::SMAX; | ||
| break; | ||
| } | ||
|
||
|
|
||
| return createMinMaxReduction(Op->getOperand(0), Op.getValueType(), SDLoc(Op), | ||
| BinOp, DAG, Subtarget); | ||
| } | ||
|
|
||
| static SDValue LowerSIGN_EXTEND(SDValue Op, const X86Subtarget &Subtarget, | ||
| SelectionDAG &DAG) { | ||
| MVT VT = Op->getSimpleValueType(0); | ||
|
|
@@ -33620,6 +33722,11 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { | |
| case ISD::ZERO_EXTEND_VECTOR_INREG: | ||
| case ISD::SIGN_EXTEND_VECTOR_INREG: | ||
| return LowerEXTEND_VECTOR_INREG(Op, Subtarget, DAG); | ||
| case ISD::VECREDUCE_UMIN: | ||
| case ISD::VECREDUCE_UMAX: | ||
| case ISD::VECREDUCE_SMIN: | ||
| case ISD::VECREDUCE_SMAX: | ||
| return LowerVECTOR_REDUCE_MINMAX(Op, Subtarget, DAG); | ||
| case ISD::FP_TO_SINT: | ||
| case ISD::STRICT_FP_TO_SINT: | ||
| case ISD::FP_TO_UINT: | ||
|
|
@@ -46192,60 +46299,8 @@ static SDValue combineMinMaxReduction(SDNode *Extract, SelectionDAG &DAG, | |
| if (!Src) | ||
| return SDValue(); | ||
|
|
||
| EVT SrcVT = Src.getValueType(); | ||
| EVT SrcSVT = SrcVT.getScalarType(); | ||
| if (SrcSVT != ExtractVT || (SrcVT.getSizeInBits() % 128) != 0) | ||
| return SDValue(); | ||
|
|
||
| SDLoc DL(Extract); | ||
| SDValue MinPos = Src; | ||
|
|
||
| // First, reduce the source down to 128-bit, applying BinOp to lo/hi. | ||
| while (SrcVT.getSizeInBits() > 128) { | ||
| SDValue Lo, Hi; | ||
| std::tie(Lo, Hi) = splitVector(MinPos, DAG, DL); | ||
| SrcVT = Lo.getValueType(); | ||
| MinPos = DAG.getNode(BinOp, DL, SrcVT, Lo, Hi); | ||
| } | ||
| assert(((SrcVT == MVT::v8i16 && ExtractVT == MVT::i16) || | ||
| (SrcVT == MVT::v16i8 && ExtractVT == MVT::i8)) && | ||
| "Unexpected value type"); | ||
|
|
||
| // PHMINPOSUW applies to UMIN(v8i16), for SMIN/SMAX/UMAX we must apply a mask | ||
| // to flip the value accordingly. | ||
| SDValue Mask; | ||
| unsigned MaskEltsBits = ExtractVT.getSizeInBits(); | ||
| if (BinOp == ISD::SMAX) | ||
| Mask = DAG.getConstant(APInt::getSignedMaxValue(MaskEltsBits), DL, SrcVT); | ||
| else if (BinOp == ISD::SMIN) | ||
| Mask = DAG.getConstant(APInt::getSignedMinValue(MaskEltsBits), DL, SrcVT); | ||
| else if (BinOp == ISD::UMAX) | ||
| Mask = DAG.getAllOnesConstant(DL, SrcVT); | ||
|
|
||
| if (Mask) | ||
| MinPos = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, MinPos); | ||
|
|
||
| // For v16i8 cases we need to perform UMIN on pairs of byte elements, | ||
| // shuffling each upper element down and insert zeros. This means that the | ||
| // v16i8 UMIN will leave the upper element as zero, performing zero-extension | ||
| // ready for the PHMINPOS. | ||
| if (ExtractVT == MVT::i8) { | ||
| SDValue Upper = DAG.getVectorShuffle( | ||
| SrcVT, DL, MinPos, DAG.getConstant(0, DL, MVT::v16i8), | ||
| {1, 16, 3, 16, 5, 16, 7, 16, 9, 16, 11, 16, 13, 16, 15, 16}); | ||
| MinPos = DAG.getNode(ISD::UMIN, DL, SrcVT, MinPos, Upper); | ||
| } | ||
|
|
||
| // Perform the PHMINPOS on a v8i16 vector, | ||
| MinPos = DAG.getBitcast(MVT::v8i16, MinPos); | ||
| MinPos = DAG.getNode(X86ISD::PHMINPOS, DL, MVT::v8i16, MinPos); | ||
| MinPos = DAG.getBitcast(SrcVT, MinPos); | ||
|
|
||
| if (Mask) | ||
| MinPos = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, MinPos); | ||
|
|
||
| return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractVT, MinPos, | ||
| DAG.getVectorIdxConstant(0, DL)); | ||
| return createMinMaxReduction(Src, ExtractVT, SDLoc(Extract), | ||
|
||
| BinOp, DAG, Subtarget); | ||
| } | ||
|
|
||
| // Attempt to replace an all_of/any_of/parity style horizontal reduction with a MOVMSK. | ||
|
|
@@ -47081,7 +47136,8 @@ static SDValue combineArithReduction(SDNode *ExtElt, SelectionDAG &DAG, | |
| /// scalars back, while for x64 we should use 64-bit extracts and shifts. | ||
| static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, | ||
| TargetLowering::DAGCombinerInfo &DCI, | ||
| const X86Subtarget &Subtarget) { | ||
| const X86Subtarget &Subtarget, | ||
| bool& TransformedBinOpReduction) { | ||
| if (SDValue NewOp = combineExtractWithShuffle(N, DAG, DCI, Subtarget)) | ||
| return NewOp; | ||
|
|
||
|
|
@@ -47169,23 +47225,33 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, | |
| // Check whether this extract is the root of a sum of absolute differences | ||
| // pattern. This has to be done here because we really want it to happen | ||
| // pre-legalization, | ||
| if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget)) | ||
| if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget)) { | ||
| TransformedBinOpReduction = true; | ||
| return SAD; | ||
| } | ||
|
|
||
| if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget)) | ||
| if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget)) { | ||
| TransformedBinOpReduction = true; | ||
| return VPDPBUSD; | ||
| } | ||
|
|
||
| // Attempt to replace an all_of/any_of horizontal reduction with a MOVMSK. | ||
| if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget)) | ||
| if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget)) { | ||
| TransformedBinOpReduction = true; | ||
| return Cmp; | ||
| } | ||
|
|
||
| // Attempt to replace min/max v8i16/v16i8 reductions with PHMINPOSUW. | ||
| if (SDValue MinMax = combineMinMaxReduction(N, DAG, Subtarget)) | ||
| if (SDValue MinMax = combineMinMaxReduction(N, DAG, Subtarget)) { | ||
| TransformedBinOpReduction = true; | ||
| return MinMax; | ||
| } | ||
|
|
||
| // Attempt to optimize ADD/FADD/MUL reductions with HADD, promotion etc.. | ||
| if (SDValue V = combineArithReduction(N, DAG, Subtarget)) | ||
| if (SDValue V = combineArithReduction(N, DAG, Subtarget)) { | ||
| TransformedBinOpReduction = true; | ||
| return V; | ||
| } | ||
|
|
||
| if (SDValue V = scalarizeExtEltFP(N, DAG, Subtarget, DCI)) | ||
| return V; | ||
|
|
@@ -47255,6 +47321,36 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG, | |
| return SDValue(); | ||
| } | ||
|
|
||
| static SDValue combineExtractVectorEltAndOperand(SDNode* N, SelectionDAG& DAG, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the combine independent of moving code to the lowering? Better to split them if it is.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They are independent, but in the end they serve the same purpose. Should I open a separate PR for the lowering?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. That would be good to show which code affect the test change. |
||
| TargetLowering::DAGCombinerInfo& DCI, | ||
| const X86Subtarget& Subtarget) | ||
| { | ||
| bool TransformedBinOpReduction = false; | ||
| auto Op = combineExtractVectorElt(N, DAG, DCI, Subtarget, TransformedBinOpReduction); | ||
|
|
||
| if (TransformedBinOpReduction) | ||
| { | ||
| // In case we simplified N = extract_vector_element(V, 0) with Op and V | ||
| // resulted from a reduction, then we need to replace all uses of V with | ||
| // scalar_to_vector(Op) to make sure that we eliminated the binop + shuffle | ||
| // pyramid. This is safe to do, because the elements of V are undefined except | ||
| // for the zeroth element and Op does not depend on V. | ||
|
|
||
| auto OldV = N->getOperand(0); | ||
| assert(!Op.getNode()->hasPredecessor(OldV.getNode()) && | ||
| "Op must not depend on the converted reduction"); | ||
|
|
||
| auto NewV = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), OldV->getValueType(0), Op); | ||
|
|
||
| auto NV = DCI.CombineTo(N, Op); | ||
| DCI.CombineTo(OldV.getNode(), NewV); | ||
|
|
||
| Op = NV; // Return N so it doesn't get rechecked! | ||
| } | ||
|
|
||
| return Op; | ||
| } | ||
|
|
||
| // Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)). | ||
| // This is more or less the reverse of combineBitcastvxi1. | ||
| static SDValue combineToExtendBoolVectorInReg( | ||
|
|
@@ -60702,7 +60798,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, | |
| case ISD::EXTRACT_VECTOR_ELT: | ||
| case X86ISD::PEXTRW: | ||
| case X86ISD::PEXTRB: | ||
| return combineExtractVectorElt(N, DAG, DCI, Subtarget); | ||
| return combineExtractVectorEltAndOperand(N, DAG, DCI, Subtarget); | ||
| case ISD::CONCAT_VECTORS: | ||
| return combineCONCAT_VECTORS(N, DAG, DCI, Subtarget); | ||
| case ISD::INSERT_SUBVECTOR: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SSE should only be declaring v816/v16i8 custom if we can.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
X86TTIImpl::shouldExpandReduction does a more thorough check. If we can't lower them properly they will get expanded to shuffles and we will never see these reductions in the DAG.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That sounds like the ExpandReductions pass isn't accounting for type legalisation properly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now I see what you mean. I added checks here to only do lowering for vector types that support PHMINPOSUW.