diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index ea1435c3934be..5c686f2f3907f 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -25100,26 +25100,26 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { // Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find // if the subvector can be sourced for free. -static SDValue getSubVectorSrc(SDValue V, SDValue Index, EVT SubVT) { +static SDValue getSubVectorSrc(SDValue V, unsigned Index, EVT SubVT) { if (V.getOpcode() == ISD::INSERT_SUBVECTOR && - V.getOperand(1).getValueType() == SubVT && V.getOperand(2) == Index) { + V.getOperand(1).getValueType() == SubVT && + V.getConstantOperandAPInt(2) == Index) { return V.getOperand(1); } - auto *IndexC = dyn_cast(Index); - if (IndexC && V.getOpcode() == ISD::CONCAT_VECTORS && + if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getOperand(0).getValueType() == SubVT && - (IndexC->getZExtValue() % SubVT.getVectorMinNumElements()) == 0) { - uint64_t SubIdx = IndexC->getZExtValue() / SubVT.getVectorMinNumElements(); + (Index % SubVT.getVectorMinNumElements()) == 0) { + uint64_t SubIdx = Index / SubVT.getVectorMinNumElements(); return V.getOperand(SubIdx); } return SDValue(); } -static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract, +static SDValue narrowInsertExtractVectorBinOp(EVT SubVT, SDValue BinOp, + unsigned Index, const SDLoc &DL, SelectionDAG &DAG, bool LegalOperations) { const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - SDValue BinOp = Extract->getOperand(0); unsigned BinOpcode = BinOp.getOpcode(); if (!TLI.isBinOp(BinOpcode) || BinOp->getNumValues() != 1) return SDValue(); @@ -25128,9 +25128,6 @@ static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract, SDValue Bop0 = BinOp.getOperand(0), Bop1 = BinOp.getOperand(1); if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType()) return SDValue(); - - SDValue Index = Extract->getOperand(1); - EVT SubVT = Extract->getValueType(0); if (!TLI.isOperationLegalOrCustom(BinOpcode, SubVT, LegalOperations)) return SDValue(); @@ -25146,29 +25143,25 @@ static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract, // We are inserting both operands of the wide binop only to extract back // to the narrow vector size. Eliminate all of the insert/extract: // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y - return DAG.getNode(BinOpcode, SDLoc(Extract), SubVT, Sub0, Sub1, - BinOp->getFlags()); + return DAG.getNode(BinOpcode, DL, SubVT, Sub0, Sub1, BinOp->getFlags()); } /// If we are extracting a subvector produced by a wide binary operator try /// to use a narrow binary operator and/or avoid concatenation and extraction. -static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG, +static SDValue narrowExtractedVectorBinOp(EVT VT, SDValue Src, unsigned Index, + const SDLoc &DL, SelectionDAG &DAG, bool LegalOperations) { // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share // some of these bailouts with other transforms. - if (SDValue V = narrowInsertExtractVectorBinOp(Extract, DAG, LegalOperations)) + if (SDValue V = narrowInsertExtractVectorBinOp(VT, Src, Index, DL, DAG, + LegalOperations)) return V; - // The extract index must be a constant, so we can map it to a concat operand. - auto *ExtractIndexC = dyn_cast(Extract->getOperand(1)); - if (!ExtractIndexC) - return SDValue(); - // We are looking for an optionally bitcasted wide vector binary operator // feeding an extract subvector. const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0)); + SDValue BinOp = peekThroughBitcasts(Src); unsigned BOpcode = BinOp.getOpcode(); if (!TLI.isBinOp(BOpcode) || BinOp->getNumValues() != 1) return SDValue(); @@ -25190,9 +25183,7 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG, if (!WideBVT.isFixedLengthVector()) return SDValue(); - EVT VT = Extract->getValueType(0); - unsigned ExtractIndex = ExtractIndexC->getZExtValue(); - assert(ExtractIndex % VT.getVectorNumElements() == 0 && + assert((Index % VT.getVectorNumElements()) == 0 && "Extract index is not a multiple of the vector length."); // Bail out if this is not a proper multiple width extraction. @@ -25219,12 +25210,11 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG, // for concat ops. The narrow binop alone makes this transform profitable. // We can't just reuse the original extract index operand because we may have // bitcasted. - unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements(); + unsigned ConcatOpNum = Index / VT.getVectorNumElements(); unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements(); if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) && - BinOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) { + BinOp.hasOneUse() && Src->hasOneUse()) { // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N) - SDLoc DL(Extract); SDValue NewExtIndex = DAG.getVectorIdxConstant(ExtBOIdx, DL); SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT, BinOp.getOperand(0), NewExtIndex); @@ -25264,7 +25254,6 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG, // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC) // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN - SDLoc DL(Extract); SDValue IndexC = DAG.getVectorIdxConstant(ExtBOIdx, DL); SDValue X = SubVecL ? DAG.getBitcast(NarrowBVT, SubVecL) : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT, @@ -25284,24 +25273,24 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG, /// If we are extracting a subvector from a wide vector load, convert to a /// narrow load to eliminate the extraction: /// (extract_subvector (load wide vector)) --> (load narrow vector) -static SDValue narrowExtractedVectorLoad(SDNode *Extract, const SDLoc &DL, - SelectionDAG &DAG) { +static SDValue narrowExtractedVectorLoad(EVT VT, SDValue Src, unsigned Index, + const SDLoc &DL, SelectionDAG &DAG) { // TODO: Add support for big-endian. The offset calculation must be adjusted. if (DAG.getDataLayout().isBigEndian()) return SDValue(); - auto *Ld = dyn_cast(Extract->getOperand(0)); + auto *Ld = dyn_cast(Src); if (!Ld || !ISD::isNormalLoad(Ld) || !Ld->isSimple()) return SDValue(); - // Allow targets to opt-out. - EVT VT = Extract->getValueType(0); - // We can only create byte sized loads. if (!VT.isByteSized()) return SDValue(); - unsigned Index = Extract->getConstantOperandVal(1); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + if (!TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, VT)) + return SDValue(); + unsigned NumElts = VT.getVectorMinNumElements(); // A fixed length vector being extracted from a scalable vector // may not be any *smaller* than the scalable one. @@ -25319,7 +25308,6 @@ static SDValue narrowExtractedVectorLoad(SDNode *Extract, const SDLoc &DL, if (Offset.isFixed()) ByteOffset = Offset.getFixedValue(); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT, ByteOffset)) return SDValue(); @@ -25350,23 +25338,18 @@ static SDValue narrowExtractedVectorLoad(SDNode *Extract, const SDLoc &DL, /// iff it is legal and profitable to do so. Notably, the trimmed mask /// (containing only the elements that are extracted) /// must reference at most two subvectors. -static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N, +static SDValue foldExtractSubvectorFromShuffleVector(EVT NarrowVT, SDValue Src, + unsigned Index, + const SDLoc &DL, SelectionDAG &DAG, - const TargetLowering &TLI, bool LegalOperations) { - assert(N->getOpcode() == ISD::EXTRACT_SUBVECTOR && - "Must only be called on EXTRACT_SUBVECTOR's"); - - SDValue N0 = N->getOperand(0); - // Only deal with non-scalable vectors. - EVT NarrowVT = N->getValueType(0); - EVT WideVT = N0.getValueType(); + EVT WideVT = Src.getValueType(); if (!NarrowVT.isFixedLengthVector() || !WideVT.isFixedLengthVector()) return SDValue(); // The operand must be a shufflevector. - auto *WideShuffleVector = dyn_cast(N0); + auto *WideShuffleVector = dyn_cast(Src); if (!WideShuffleVector) return SDValue(); @@ -25375,13 +25358,13 @@ static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N, return SDValue(); // And the narrow shufflevector that we'll form must be legal. + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); if (LegalOperations && !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, NarrowVT)) return SDValue(); - uint64_t FirstExtractedEltIdx = N->getConstantOperandVal(1); int NumEltsExtracted = NarrowVT.getVectorNumElements(); - assert((FirstExtractedEltIdx % NumEltsExtracted) == 0 && + assert((Index % NumEltsExtracted) == 0 && "Extract index is not a multiple of the output vector length."); int WideNumElts = WideVT.getVectorNumElements(); @@ -25392,8 +25375,7 @@ static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N, DemandedSubvectors; // Try to decode the wide mask into narrow mask from at most two subvectors. - for (int M : WideShuffleVector->getMask().slice(FirstExtractedEltIdx, - NumEltsExtracted)) { + for (int M : WideShuffleVector->getMask().slice(Index, NumEltsExtracted)) { assert((M >= -1) && (M < (2 * WideNumElts)) && "Out-of-bounds shuffle mask?"); @@ -25476,8 +25458,6 @@ static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N, !TLI.isShuffleMaskLegal(NewMask, NarrowVT)) return SDValue(); - SDLoc DL(N); - SmallVector NewOps; for (const std::pair &DemandedSubvector : DemandedSubvectors) { @@ -25507,9 +25487,8 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) { if (V.isUndef()) return DAG.getUNDEF(NVT); - if (TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, NVT)) - if (SDValue NarrowLoad = narrowExtractedVectorLoad(N, DL, DAG)) - return NarrowLoad; + if (SDValue NarrowLoad = narrowExtractedVectorLoad(NVT, V, ExtIdx, DL, DAG)) + return NarrowLoad; // Combine an extract of an extract into a single extract_subvector. // ext (ext X, C), 0 --> ext X, C @@ -25631,9 +25610,13 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) { } } - if (SDValue V = - foldExtractSubvectorFromShuffleVector(N, DAG, TLI, LegalOperations)) - return V; + if (SDValue Shuffle = foldExtractSubvectorFromShuffleVector( + NVT, V, ExtIdx, DL, DAG, LegalOperations)) + return Shuffle; + + if (SDValue NarrowBOp = + narrowExtractedVectorBinOp(NVT, V, ExtIdx, DL, DAG, LegalOperations)) + return NarrowBOp; V = peekThroughBitcasts(V); @@ -25694,9 +25677,6 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) { } } - if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG, LegalOperations)) - return NarrowBOp; - // If only EXTRACT_SUBVECTOR nodes use the source vector we can // simplify it based on the (valid) extractions. if (!V.getValueType().isScalableVector() &&