@@ -39659,13 +39659,6 @@ static bool matchBinaryPermuteShuffle(
3965939659 return false;
3966039660}
3966139661
39662- static SDValue combineX86ShuffleChainWithExtract(
39663- ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
39664- ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
39665- bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
39666- bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
39667- const X86Subtarget &Subtarget); // Forward declaration: combineX86ShuffleChain calls this before its definition appears.
39668-
3966939662/// Combine an arbitrary chain of shuffles into a single instruction if
3967039663/// possible.
3967139664///
@@ -40210,14 +40203,6 @@ static SDValue combineX86ShuffleChain(
4021040203 return DAG.getBitcast(RootVT, Res);
4021140204 }
4021240205
40213- // If that failed and either input is extracted then try to combine as a
40214- // shuffle with the larger type.
40215- if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
40216- Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
40217- AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
40218- IsMaskedShuffle, DAG, DL, Subtarget))
40219- return WideShuffle;
40220-
4022140206 // If we have a dual input lane-crossing shuffle then lower to VPERMV3,
4022240207 // (non-VLX will pad to 512-bit shuffles).
4022340208 if (AllowVariableCrossLaneMask && !MaskContainsZeros &&
@@ -40383,14 +40368,6 @@ static SDValue combineX86ShuffleChain(
4038340368 return DAG.getBitcast(RootVT, Res);
4038440369 }
4038540370
40386- // If that failed and either input is extracted then try to combine as a
40387- // shuffle with the larger type.
40388- if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
40389- Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
40390- AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
40391- DAG, DL, Subtarget))
40392- return WideShuffle;
40393-
4039440371 // If we have a dual input shuffle then lower to VPERMV3,
4039540372 // (non-VLX will pad to 512-bit shuffles)
4039640373 if (!UnaryShuffle && AllowVariablePerLaneMask && !MaskContainsZeros &&
@@ -40416,148 +40393,6 @@ static SDValue combineX86ShuffleChain(
4041640393 return SDValue();
4041740394}
4041840395
40419- // Combine an arbitrary chain of shuffles + extract_subvectors into a single
40420- // instruction if possible. Returns an empty SDValue() if no combine is made.
40421- //
40422- // Wrapper for combineX86ShuffleChain that extends the shuffle mask to a larger
40423- // type size to attempt to combine:
40424- // shuffle(extract_subvector(x,c1),extract_subvector(y,c2),m1)
40425- // -->
40426- // extract_subvector(shuffle(x,y,m2),0) (the result is then bitcast back to RootVT)
40427- static SDValue combineX86ShuffleChainWithExtract(
40428- ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
40429- ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
40430- bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
40431- bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
40432- const X86Subtarget &Subtarget) {
40433- unsigned NumMaskElts = BaseMask.size();
40434- unsigned NumInputs = Inputs.size();
40435- if (NumInputs == 0) // No inputs - nothing to widen.
40436- return SDValue();
40437-
40438- unsigned RootSizeInBits = RootVT.getSizeInBits();
40439- unsigned RootEltSizeInBits = RootSizeInBits / NumMaskElts; // NOTE(review): evaluated before the assert below; assumes BaseMask is non-empty - confirm with callers.
40440- assert((RootSizeInBits % NumMaskElts) == 0 && "Unexpected root shuffle mask");
40441-
40442- // Peek through bitcasts, extract_subvector and insert_subvector(undef, x, idx)
40443- // on every input to find the widest legal source vector. TODO: Handle ISD::TRUNCATE
40444- unsigned WideSizeInBits = RootSizeInBits;
40445- for (SDValue Input : Inputs) {
40446- Input = peekThroughBitcasts(Input);
40447- while (1) {
40448- if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
40449- Input = peekThroughBitcasts(Input.getOperand(0));
40450- continue;
40451- }
40452- if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
40453- Input.getOperand(0).isUndef()) {
40454- Input = peekThroughBitcasts(Input.getOperand(1));
40455- continue;
40456- }
40457- break;
40458- }
40459- if (DAG.getTargetLoweringInfo().isTypeLegal(Input.getValueType()) &&
40460- WideSizeInBits < Input.getValueSizeInBits())
40461- WideSizeInBits = Input.getValueSizeInBits(); // Only legal types may raise the target width.
40462- }
40463-
40464- // Bail if we fail to find a source larger than the existing root, or if its size is not a whole multiple of the root size.
40465- if (WideSizeInBits <= RootSizeInBits ||
40466- (WideSizeInBits % RootSizeInBits) != 0)
40467- return SDValue();
40468-
40469- // Create new mask for larger type.
40470- SmallVector<int, 64> WideMask;
40471- growShuffleMask(BaseMask, WideMask, RootSizeInBits, WideSizeInBits);
40472-
40473- // Attempt to peek through inputs and adjust mask when we extract from an
40474- // upper subvector. AdjustedMasks counts the non-zero extraction indices peeked through.
40475- int AdjustedMasks = 0;
40476- SmallVector<SDValue, 4> WideInputs(Inputs);
40477- for (unsigned I = 0; I != NumInputs; ++I) {
40478- SDValue &Input = WideInputs[I];
40479- Input = peekThroughBitcasts(Input);
40480- while (1) {
40481- if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40482- Input.getOperand(0).getValueSizeInBits() <= WideSizeInBits) {
40483- uint64_t Idx = Input.getConstantOperandVal(1);
40484- if (Idx != 0) {
40485- ++AdjustedMasks;
40486- unsigned InputEltSizeInBits = Input.getScalarValueSizeInBits();
40487- Idx = (Idx * InputEltSizeInBits) / RootEltSizeInBits; // Rescale the extract index from input-element units to root-mask-element units.
40488-
40489- int lo = I * WideMask.size(); // [lo, hi) - presumably the mask values that select from input I.
40490- int hi = (I + 1) * WideMask.size();
40491- for (int &M : WideMask)
40492- if (lo <= M && M < hi)
40493- M += Idx;
40494- }
40495- Input = peekThroughBitcasts(Input.getOperand(0));
40496- continue;
40497- }
40498- // TODO: Handle insertions into upper subvectors.
40499- if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
40500- Input.getOperand(0).isUndef() &&
40501- isNullConstant(Input.getOperand(2))) {
40502- Input = peekThroughBitcasts(Input.getOperand(1));
40503- continue;
40504- }
40505- break;
40506- }
40507- }
40508-
40509- // Remove unused/repeated shuffle source ops.
40510- resolveTargetShuffleInputsAndMask(WideInputs, WideMask);
40511- assert(!WideInputs.empty() && "Shuffle with no inputs detected");
40512-
40513- // Bail if no input was extracted from an upper subvector (we only ever used
40514- // the lowest subvectors, which combineX86ShuffleChain should already match at
40515- // the current width), or if the shuffle still references more than two inputs.
40516- if (AdjustedMasks == 0 || WideInputs.size() > 2)
40517- return SDValue();
40518-
40519- // Minor canonicalization of the accumulated shuffle mask to make it easier
40520- // to match below. All this does is detect masks with sequential pairs of
40521- // elements, and shrink them to the half-width mask. It does this in a loop
40522- // so it will reduce the size of the mask to the minimal width mask which
40523- // performs an equivalent shuffle.
40524- while (WideMask.size() > 1) {
40525- SmallVector<int, 64> WidenedMask;
40526- if (!canWidenShuffleElements(WideMask, WidenedMask))
40527- break;
40528- WideMask = std::move(WidenedMask);
40529- }
40530-
40531- // Canonicalization of binary shuffle masks to improve pattern matching by
40532- // commuting the inputs (the mask and the two inputs are swapped together).
40533- if (WideInputs.size() == 2 && canonicalizeShuffleMaskWithCommute(WideMask)) {
40534- ShuffleVectorSDNode::commuteMask(WideMask);
40535- std::swap(WideInputs[0], WideInputs[1]);
40536- }
40537-
40538- // Increase depth for every upper subvector we've peeked through.
40539- Depth += AdjustedMasks;
40540-
40541- // Attempt to combine wider chain, rooted at the wider of the first/last wide input (the last one on a tie).
40542- // TODO: Can we use a better Root?
40543- SDValue WideRoot = WideInputs.front().getValueSizeInBits() >
40544- WideInputs.back().getValueSizeInBits()
40545- ? WideInputs.front()
40546- : WideInputs.back();
40547- assert(WideRoot.getValueSizeInBits() == WideSizeInBits &&
40548- "WideRootSize mismatch");
40549-
40550- if (SDValue WideShuffle = combineX86ShuffleChain(
40551- WideInputs, RootOpcode, WideRoot.getSimpleValueType(), WideMask,
40552- Depth, SrcNodes, AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
40553- IsMaskedShuffle, DAG, SDLoc(WideRoot), Subtarget)) {
40554- WideShuffle = extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits); // Narrow the wide result back to the root size, starting at index 0.
40555- return DAG.getBitcast(RootVT, WideShuffle);
40556- }
40557-
40558- return SDValue(); // The wider combine failed as well.
40559- }
40560-
4056140396// Canonicalize the combined shuffle mask chain with horizontal ops.
4056240397// NOTE: This may update the Ops and Mask.
4056340398static SDValue canonicalizeShuffleMaskWithHorizOp(
@@ -40970,6 +40805,54 @@ static SDValue combineX86ShufflesRecursively(
4097040805 OpMask.assign(NumElts, SM_SentinelUndef);
4097140806 std::iota(OpMask.begin(), OpMask.end(), ExtractIdx);
4097240807 OpZero = OpUndef = APInt::getZero(NumElts);
40808+ } else if (Op.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40809+ TLI.isTypeLegal(Op.getOperand(0).getValueType()) &&
40810+ Op.getOperand(0).getValueSizeInBits() > RootSizeInBits &&
40811+ (Op.getOperand(0).getValueSizeInBits() % RootSizeInBits) == 0) {
40812+ // Extracting from a vector larger than RootVT - scale the mask and attempt to
40813+ // fold the shuffle with the larger root type, then extract the lower
40814+ // elements.
40815+ unsigned NewRootSizeInBits = Op.getOperand(0).getValueSizeInBits();
40816+ unsigned Scale = NewRootSizeInBits / RootSizeInBits;
40817+ MVT NewRootVT = MVT::getVectorVT(RootVT.getScalarType(),
40818+ Scale * RootVT.getVectorNumElements());
40819+ SmallVector<int, 64> NewRootMask;
40820+ growShuffleMask(RootMask, NewRootMask, RootSizeInBits, NewRootSizeInBits);
40821+ // If we're using the lowest subvector, just replace it directly in the src
40822+ // ops/nodes.
40823+ SmallVector<SDValue, 16> NewSrcOps(SrcOps);
40824+ SmallVector<const SDNode *, 16> NewSrcNodes(SrcNodes);
40825+ if (isNullConstant(Op.getOperand(1))) {
40826+ NewSrcOps[SrcOpIndex] = Op.getOperand(0);
40827+ NewSrcNodes.push_back(Op.getNode());
40828+ }
40829+ // Don't increase the combine depth - we're effectively working on the same
40830+ // nodes, just with a wider type.
40831+ if (SDValue WideShuffle = combineX86ShufflesRecursively(
40832+ NewSrcOps, SrcOpIndex, RootOpc, NewRootVT, NewRootMask, NewSrcNodes,
40833+ Depth, MaxDepth, AllowVariableCrossLaneMask,
40834+ AllowVariablePerLaneMask, IsMaskedShuffle, DAG, DL, Subtarget))
40835+ return DAG.getBitcast(
40836+ RootVT, extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits));
40837+ return SDValue();
40838+ } else if (Op.getOpcode() == ISD::INSERT_SUBVECTOR &&
40839+ Op.getOperand(1).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40840+ Op.getOperand(1).getOperand(0).getValueSizeInBits() >
40841+ RootSizeInBits) {
40842+ // If we're inserting a subvector extracted from a vector larger than
40843+ // RootVT, then combine the insert_subvector as a shuffle, the
40844+ // extract_subvector will be folded in a later recursion.
40845+ SDValue BaseVec = Op.getOperand(0);
40846+ SDValue SubVec = Op.getOperand(1);
40847+ int InsertIdx = Op.getConstantOperandVal(2);
40848+ unsigned NumBaseElts = VT.getVectorNumElements();
40849+ unsigned NumSubElts = SubVec.getValueType().getVectorNumElements();
40850+ OpInputs.assign({BaseVec, SubVec});
40851+ OpMask.resize(NumBaseElts);
40852+ std::iota(OpMask.begin(), OpMask.end(), 0);
40853+ std::iota(OpMask.begin() + InsertIdx,
40854+ OpMask.begin() + InsertIdx + NumSubElts, NumBaseElts);
40855+ OpZero = OpUndef = APInt::getZero(NumBaseElts);
4097340856 } else {
4097440857 return SDValue();
4097540858 }
@@ -41316,25 +41199,9 @@ static SDValue combineX86ShufflesRecursively(
4131641199 AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
4131741200 IsMaskedShuffle, DAG, DL, Subtarget))
4131841201 return Shuffle;
41319-
41320- // If all the operands come from the same larger vector, fallthrough and try
41321- // to use combineX86ShuffleChainWithExtract.
41322- SDValue LHS = peekThroughBitcasts(Ops.front());
41323- SDValue RHS = peekThroughBitcasts(Ops.back());
41324- if (Ops.size() != 2 || !Subtarget.hasAVX2() || RootSizeInBits != 128 ||
41325- (RootSizeInBits / Mask.size()) != 64 ||
41326- LHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
41327- RHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
41328- LHS.getOperand(0) != RHS.getOperand(0))
41329- return SDValue();
4133041202 }
4133141203
41332- // If that failed and any input is extracted then try to combine as a
41333- // shuffle with the larger type.
41334- return combineX86ShuffleChainWithExtract(
41335- Ops, RootOpc, RootVT, Mask, Depth, CombinedNodes,
41336- AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
41337- DAG, DL, Subtarget);
41204+ return SDValue();
4133841205}
4133941206
4134041207/// Helper entry wrapper to combineX86ShufflesRecursively.
@@ -43947,6 +43814,7 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
4394743814 case X86ISD::UNPCKL:
4394843815 case X86ISD::UNPCKH:
4394943816 case X86ISD::BLENDI:
43817+ case X86ISD::SHUFP:
4395043818 // Integer ops.
4395143819 case X86ISD::PACKSS:
4395243820 case X86ISD::PACKUS:
0 commit comments