@@ -47690,12 +47690,47 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
4769047690 return V;
4769147691
4769247692 if (N->getOpcode() == ISD::VSELECT || N->getOpcode() == X86ISD::BLENDV) {
47693- SmallVector<int, 64> Mask ;
47694- if (createShuffleMaskFromVSELECT(Mask , Cond,
47693+ SmallVector<int, 64> CondMask ;
47694+ if (createShuffleMaskFromVSELECT(CondMask , Cond,
4769547695 N->getOpcode() == X86ISD::BLENDV)) {
4769647696 // Convert vselects with constant condition into shuffles.
4769747697 if (DCI.isBeforeLegalizeOps())
47698- return DAG.getVectorShuffle(VT, DL, LHS, RHS, Mask);
47698+ return DAG.getVectorShuffle(VT, DL, LHS, RHS, CondMask);
47699+
47700+ // fold vselect(cond, pshufb(x), pshufb(y)) -> or (pshufb(x), pshufb(y))
47701+ // by forcing the unselected elements to zero.
47702+ // TODO: Can we handle more shuffles with this?
47703+ if (LHS.hasOneUse() && RHS.hasOneUse()) {
47704+ SmallVector<SDValue, 1> LHSOps, RHSOps;
47705+ SmallVector<int, 64> LHSMask, RHSMask, ByteMask;
47706+ SDValue LHSShuf = peekThroughOneUseBitcasts(LHS);
47707+ SDValue RHSShuf = peekThroughOneUseBitcasts(RHS);
47708+ if (LHSShuf.getOpcode() == X86ISD::PSHUFB &&
47709+ RHSShuf.getOpcode() == X86ISD::PSHUFB &&
47710+ scaleShuffleMaskElts(VT.getSizeInBits() / 8, CondMask, ByteMask) &&
47711+ getTargetShuffleMask(LHSShuf, true, LHSOps, LHSMask) &&
47712+ getTargetShuffleMask(RHSShuf, true, RHSOps, RHSMask)) {
47713+ assert(ByteMask.size() == LHSMask.size() &&
47714+ ByteMask.size() == RHSMask.size() && "Shuffle mask mismatch");
47715+ for (auto [I, M] : enumerate(ByteMask)) {
47716+ // getConstVector sets negative shuffle mask values as undef, so
47717+ // ensure we hardcode SM_SentinelZero values to zero (0x80).
47718+ if (M < (int)ByteMask.size()) {
47719+ LHSMask[I] = isUndefOrZero(LHSMask[I]) ? 0x80 : LHSMask[I];
47720+ RHSMask[I] = 0x80;
47721+ } else {
47722+ LHSMask[I] = 0x80;
47723+ RHSMask[I] = isUndefOrZero(RHSMask[I]) ? 0x80 : RHSMask[I];
47724+ }
47725+ }
47726+ MVT ByteVT = LHSShuf.getSimpleValueType();
47727+ LHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, LHSOps[0],
47728+ getConstVector(LHSMask, ByteVT, DAG, DL, true));
47729+ RHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, RHSOps[0],
47730+ getConstVector(RHSMask, ByteVT, DAG, DL, true));
47731+ return DAG.getBitcast(VT, DAG.getNode(ISD::OR, DL, ByteVT, LHS, RHS));
47732+ }
47733+ }
4769947734
4770047735 // Attempt to combine as shuffle.
4770147736 SDValue Op(N, 0);
@@ -47704,43 +47739,6 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
4770447739 }
4770547740 }
4770647741
47707- // fold vselect(cond, pshufb(x), pshufb(y)) -> or (pshufb(x), pshufb(y))
47708- // by forcing the unselected elements to zero.
47709- // TODO: Can we handle more shuffles with this?
47710- if (N->getOpcode() == ISD::VSELECT && CondVT.isVector() && LHS.hasOneUse() &&
47711- RHS.hasOneUse()) {
47712- SmallVector<SDValue, 1> LHSOps, RHSOps;
47713- SmallVector<int, 64> LHSMask, RHSMask, CondMask, ByteMask;
47714- SDValue LHSShuf = peekThroughOneUseBitcasts(LHS);
47715- SDValue RHSShuf = peekThroughOneUseBitcasts(RHS);
47716- if (LHSShuf.getOpcode() == X86ISD::PSHUFB &&
47717- RHSShuf.getOpcode() == X86ISD::PSHUFB &&
47718- createShuffleMaskFromVSELECT(CondMask, Cond) &&
47719- scaleShuffleMaskElts(VT.getSizeInBits() / 8, CondMask, ByteMask) &&
47720- getTargetShuffleMask(LHSShuf, true, LHSOps, LHSMask) &&
47721- getTargetShuffleMask(RHSShuf, true, RHSOps, RHSMask)) {
47722- assert(ByteMask.size() == LHSMask.size() &&
47723- ByteMask.size() == RHSMask.size() && "Shuffle mask mismatch");
47724- for (auto [I, M] : enumerate(ByteMask)) {
47725- // getConstVector sets negative shuffle mask values as undef, so ensure
47726- // we hardcode SM_SentinelZero values to zero (0x80).
47727- if (M < (int)ByteMask.size()) {
47728- LHSMask[I] = isUndefOrZero(LHSMask[I]) ? 0x80 : LHSMask[I];
47729- RHSMask[I] = 0x80;
47730- } else {
47731- LHSMask[I] = 0x80;
47732- RHSMask[I] = isUndefOrZero(RHSMask[I]) ? 0x80 : RHSMask[I];
47733- }
47734- }
47735- MVT ByteVT = LHSShuf.getSimpleValueType();
47736- LHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, LHSOps[0],
47737- getConstVector(LHSMask, ByteVT, DAG, DL, true));
47738- RHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, RHSOps[0],
47739- getConstVector(RHSMask, ByteVT, DAG, DL, true));
47740- return DAG.getBitcast(VT, DAG.getNode(ISD::OR, DL, ByteVT, LHS, RHS));
47741- }
47742- }
47743-
4774447742 // If we have SSE[12] support, try to form min/max nodes. SSE min/max
4774547743 // instructions match the semantics of the common C idiom x<y?x:y but not
4774647744 // x<=y?x:y, because of how they handle negative zero (which can be
0 commit comments