@@ -39653,13 +39653,6 @@ static bool matchBinaryPermuteShuffle(
3965339653 return false;
3965439654}
3965539655
39656- static SDValue combineX86ShuffleChainWithExtract(
39657- ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
39658- ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
39659- bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
39660- bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
39661- const X86Subtarget &Subtarget);
39662-
3966339656/// Combine an arbitrary chain of shuffles into a single instruction if
3966439657/// possible.
3966539658///
@@ -40203,14 +40196,6 @@ static SDValue combineX86ShuffleChain(
4020340196 return DAG.getBitcast(RootVT, Res);
4020440197 }
4020540198
40206- // If that failed and either input is extracted then try to combine as a
40207- // shuffle with the larger type.
40208- if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
40209- Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
40210- AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
40211- IsMaskedShuffle, DAG, DL, Subtarget))
40212- return WideShuffle;
40213-
4021440199 // If we have a dual input lane-crossing shuffle then lower to VPERMV3,
4021540200 // (non-VLX will pad to 512-bit shuffles).
4021640201 if (AllowVariableCrossLaneMask && !MaskContainsZeros &&
@@ -40376,14 +40361,6 @@ static SDValue combineX86ShuffleChain(
4037640361 return DAG.getBitcast(RootVT, Res);
4037740362 }
4037840363
40379- // If that failed and either input is extracted then try to combine as a
40380- // shuffle with the larger type.
40381- if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
40382- Inputs, RootOpc, RootVT, BaseMask, Depth, SrcNodes,
40383- AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
40384- DAG, DL, Subtarget))
40385- return WideShuffle;
40386-
4038740364 // If we have a dual input shuffle then lower to VPERMV3,
4038840365 // (non-VLX will pad to 512-bit shuffles)
4038940366 if (!UnaryShuffle && AllowVariablePerLaneMask && !MaskContainsZeros &&
@@ -40409,154 +40386,6 @@ static SDValue combineX86ShuffleChain(
4040940386 return SDValue();
4041040387}
4041140388
40412- // Combine an arbitrary chain of shuffles + extract_subvectors into a single
40413- // instruction if possible.
40414- //
40415- // Wrapper for combineX86ShuffleChain that extends the shuffle mask to a larger
40416- // type size to attempt to combine:
40417- // shuffle(extract_subvector(x,c1),extract_subvector(y,c2),m1)
40418- // -->
40419- // extract_subvector(shuffle(x,y,m2),0)
40420- static SDValue combineX86ShuffleChainWithExtract(
40421- ArrayRef<SDValue> Inputs, unsigned RootOpcode, MVT RootVT,
40422- ArrayRef<int> BaseMask, int Depth, ArrayRef<const SDNode *> SrcNodes,
40423- bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask,
40424- bool IsMaskedShuffle, SelectionDAG &DAG, const SDLoc &DL,
40425- const X86Subtarget &Subtarget) {
40426- unsigned NumMaskElts = BaseMask.size();
40427- unsigned NumInputs = Inputs.size();
40428- if (NumInputs == 0)
40429- return SDValue();
40430-
40431- unsigned RootSizeInBits = RootVT.getSizeInBits();
40432- unsigned RootEltSizeInBits = RootSizeInBits / NumMaskElts;
40433- assert((RootSizeInBits % NumMaskElts) == 0 && "Unexpected root shuffle mask");
40434-
40435- // Peek through subvectors to find widest legal vector.
40436- // TODO: Handle ISD::TRUNCATE
40437- unsigned WideSizeInBits = RootSizeInBits;
40438- for (SDValue Input : Inputs) {
40439- Input = peekThroughBitcasts(Input);
40440- while (1) {
40441- if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
40442- Input = peekThroughBitcasts(Input.getOperand(0));
40443- continue;
40444- }
40445- if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
40446- Input.getOperand(0).isUndef()) {
40447- Input = peekThroughBitcasts(Input.getOperand(1));
40448- continue;
40449- }
40450- break;
40451- }
40452- if (DAG.getTargetLoweringInfo().isTypeLegal(Input.getValueType()) &&
40453- WideSizeInBits < Input.getValueSizeInBits())
40454- WideSizeInBits = Input.getValueSizeInBits();
40455- }
40456-
40457- // Bail if we fail to find a source larger than the existing root.
40458- unsigned Scale = WideSizeInBits / RootSizeInBits;
40459- if (WideSizeInBits <= RootSizeInBits ||
40460- (WideSizeInBits % RootSizeInBits) != 0)
40461- return SDValue();
40462-
40463- // Create new mask for larger type.
40464- SmallVector<int, 64> WideMask(BaseMask);
40465- for (int &M : WideMask) {
40466- if (M < 0)
40467- continue;
40468- M = (M % NumMaskElts) + ((M / NumMaskElts) * Scale * NumMaskElts);
40469- }
40470- WideMask.append((Scale - 1) * NumMaskElts, SM_SentinelUndef);
40471-
40472- // Attempt to peek through inputs and adjust mask when we extract from an
40473- // upper subvector.
40474- int AdjustedMasks = 0;
40475- SmallVector<SDValue, 4> WideInputs(Inputs);
40476- for (unsigned I = 0; I != NumInputs; ++I) {
40477- SDValue &Input = WideInputs[I];
40478- Input = peekThroughBitcasts(Input);
40479- while (1) {
40480- if (Input.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40481- Input.getOperand(0).getValueSizeInBits() <= WideSizeInBits) {
40482- uint64_t Idx = Input.getConstantOperandVal(1);
40483- if (Idx != 0) {
40484- ++AdjustedMasks;
40485- unsigned InputEltSizeInBits = Input.getScalarValueSizeInBits();
40486- Idx = (Idx * InputEltSizeInBits) / RootEltSizeInBits;
40487-
40488- int lo = I * WideMask.size();
40489- int hi = (I + 1) * WideMask.size();
40490- for (int &M : WideMask)
40491- if (lo <= M && M < hi)
40492- M += Idx;
40493- }
40494- Input = peekThroughBitcasts(Input.getOperand(0));
40495- continue;
40496- }
40497- // TODO: Handle insertions into upper subvectors.
40498- if (Input.getOpcode() == ISD::INSERT_SUBVECTOR &&
40499- Input.getOperand(0).isUndef() &&
40500- isNullConstant(Input.getOperand(2))) {
40501- Input = peekThroughBitcasts(Input.getOperand(1));
40502- continue;
40503- }
40504- break;
40505- }
40506- }
40507-
40508- // Remove unused/repeated shuffle source ops.
40509- resolveTargetShuffleInputsAndMask(WideInputs, WideMask);
40510- assert(!WideInputs.empty() && "Shuffle with no inputs detected");
40511-
40512- // Bail if we're always extracting from the lowest subvectors,
40513- // combineX86ShuffleChain should match this for the current width, or the
40514- // shuffle still references too many inputs.
40515- if (AdjustedMasks == 0 || WideInputs.size() > 2)
40516- return SDValue();
40517-
40518- // Minor canonicalization of the accumulated shuffle mask to make it easier
40519- // to match below. All this does is detect masks with sequential pairs of
40520- // elements, and shrink them to the half-width mask. It does this in a loop
40521- // so it will reduce the size of the mask to the minimal width mask which
40522- // performs an equivalent shuffle.
40523- while (WideMask.size() > 1) {
40524- SmallVector<int, 64> WidenedMask;
40525- if (!canWidenShuffleElements(WideMask, WidenedMask))
40526- break;
40527- WideMask = std::move(WidenedMask);
40528- }
40529-
40530- // Canonicalization of binary shuffle masks to improve pattern matching by
40531- // commuting the inputs.
40532- if (WideInputs.size() == 2 && canonicalizeShuffleMaskWithCommute(WideMask)) {
40533- ShuffleVectorSDNode::commuteMask(WideMask);
40534- std::swap(WideInputs[0], WideInputs[1]);
40535- }
40536-
40537- // Increase depth for every upper subvector we've peeked through.
40538- Depth += AdjustedMasks;
40539-
40540- // Attempt to combine wider chain.
40541- // TODO: Can we use a better Root?
40542- SDValue WideRoot = WideInputs.front().getValueSizeInBits() >
40543- WideInputs.back().getValueSizeInBits()
40544- ? WideInputs.front()
40545- : WideInputs.back();
40546- assert(WideRoot.getValueSizeInBits() == WideSizeInBits &&
40547- "WideRootSize mismatch");
40548-
40549- if (SDValue WideShuffle = combineX86ShuffleChain(
40550- WideInputs, RootOpcode, WideRoot.getSimpleValueType(), WideMask,
40551- Depth, SrcNodes, AllowVariableCrossLaneMask, AllowVariablePerLaneMask,
40552- IsMaskedShuffle, DAG, SDLoc(WideRoot), Subtarget)) {
40553- WideShuffle = extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits);
40554- return DAG.getBitcast(RootVT, WideShuffle);
40555- }
40556-
40557- return SDValue();
40558- }
40559-
4056040389// Canonicalize the combined shuffle mask chain with horizontal ops.
4056140390// NOTE: This may update the Ops and Mask.
4056240391static SDValue canonicalizeShuffleMaskWithHorizOp(
@@ -40969,6 +40798,57 @@ static SDValue combineX86ShufflesRecursively(
4096940798 OpMask.assign(NumElts, SM_SentinelUndef);
4097040799 std::iota(OpMask.begin(), OpMask.end(), ExtractIdx);
4097140800 OpZero = OpUndef = APInt::getZero(NumElts);
40801+ } else if (Op.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40802+ TLI.isTypeLegal(Op.getOperand(0).getValueType()) &&
40803+ Op.getOperand(0).getValueSizeInBits() > RootSizeInBits &&
40804+ (Op.getOperand(0).getValueSizeInBits() % RootSizeInBits) == 0) {
40805+ // Extracting from vector larger than RootVT - scale the mask and attempt to
40806+ // fold the shuffle with the larger root type, then extract the lower
40807+ // elements.
40808+ unsigned Scale = Op.getOperand(0).getValueSizeInBits() / RootSizeInBits;
40809+ MVT NewRootVT = MVT::getVectorVT(RootVT.getScalarType(),
40810+ Scale * RootVT.getVectorNumElements());
40811+ SmallVector<int, 64> NewRootMask(RootMask);
40812+ NewRootMask.append((Scale - 1) * RootMask.size(), SM_SentinelUndef);
40813+ for (int &M : NewRootMask)
40814+ if (0 <= M)
40815+ M = (M % RootMask.size()) +
40816+ ((M / RootMask.size()) * NewRootMask.size());
40817+ // If we're using the lowest subvector, just replace it directly in the src
40818+ // ops/nodes.
40819+ SmallVector<SDValue, 16> NewSrcOps(SrcOps);
40820+ SmallVector<const SDNode *, 16> NewSrcNodes(SrcNodes);
40821+ if (isNullConstant(Op.getOperand(1))) {
40822+ NewSrcOps[SrcOpIndex] = Op.getOperand(0);
40823+ NewSrcNodes.push_back(Op.getNode());
40824+ }
40825+ // Don't increase the combine depth - we're effectively working on the same
40826+ // nodes, just with a wider type.
40827+ if (SDValue WideShuffle = combineX86ShufflesRecursively(
40828+ NewSrcOps, SrcOpIndex, RootOpc, NewRootVT, NewRootMask, NewSrcNodes,
40829+ Depth, MaxDepth, AllowVariableCrossLaneMask,
40830+ AllowVariablePerLaneMask, IsMaskedShuffle, DAG, DL, Subtarget))
40831+ return DAG.getBitcast(
40832+ RootVT, extractSubVector(WideShuffle, 0, DAG, DL, RootSizeInBits));
40833+ return SDValue();
40834+ } else if (Op.getOpcode() == ISD::INSERT_SUBVECTOR &&
40835+ Op.getOperand(1).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
40836+ Op.getOperand(1).getOperand(0).getValueSizeInBits() >
40837+ RootSizeInBits) {
40838+ // If we're inserting a subvector extracted from a vector larger than
40839+ // RootVT, then combine the insert_subvector as a shuffle; the
40840+ // extract_subvector will be folded in a later recursion.
40841+ SDValue BaseVec = Op.getOperand(0);
40842+ SDValue SubVec = Op.getOperand(1);
40843+ int InsertIdx = Op.getConstantOperandVal(2);
40844+ unsigned NumBaseElts = VT.getVectorNumElements();
40845+ unsigned NumSubElts = SubVec.getValueType().getVectorNumElements();
40846+ OpInputs.assign({BaseVec, SubVec});
40847+ OpMask.assign(NumBaseElts, SM_SentinelUndef);
40848+ std::iota(OpMask.begin(), OpMask.end(), 0);
40849+ std::iota(OpMask.begin() + InsertIdx,
40850+ OpMask.begin() + InsertIdx + NumSubElts, NumBaseElts);
40851+ OpZero = OpUndef = APInt::getZero(NumBaseElts);
4097240852 } else {
4097340853 return SDValue();
4097440854 }
@@ -41324,12 +41204,7 @@ static SDValue combineX86ShufflesRecursively(
4132441204 return SDValue();
4132541205 }
4132641206
41327- // If that failed and any input is extracted then try to combine as a
41328- // shuffle with the larger type.
41329- return combineX86ShuffleChainWithExtract(
41330- Ops, RootOpc, RootVT, Mask, Depth, CombinedNodes,
41331- AllowVariableCrossLaneMask, AllowVariablePerLaneMask, IsMaskedShuffle,
41332- DAG, DL, Subtarget);
41207+ return SDValue();
4133341208}
4133441209
4133541210/// Helper entry wrapper to combineX86ShufflesRecursively.
@@ -43866,6 +43741,7 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
4386643741 case X86ISD::UNPCKL:
4386743742 case X86ISD::UNPCKH:
4386843743 case X86ISD::BLENDI:
43744+ case X86ISD::SHUFP:
4386943745 // Integer ops.
4387043746 case X86ISD::PACKSS:
4387143747 case X86ISD::PACKUS:
0 commit comments