@@ -47237,32 +47237,37 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
4723747237 // fold vselect(cond, pshufb(x), pshufb(y)) -> or (pshufb(x), pshufb(y))
4723847238 // by forcing the unselected elements to zero.
4723947239 // TODO: Can we handle more shuffles with this?
47240- if (N->getOpcode() == ISD::VSELECT && CondVT.isVector() &&
47241- LHS.getOpcode() == X86ISD::PSHUFB && RHS.getOpcode() == X86ISD::PSHUFB &&
47242- LHS.hasOneUse() && RHS.hasOneUse()) {
47243- MVT SimpleVT = VT.getSimpleVT();
47240+ if (N->getOpcode() == ISD::VSELECT && CondVT.isVector() && LHS.hasOneUse() &&
47241+ RHS.hasOneUse()) {
4724447242 SmallVector<SDValue, 1> LHSOps, RHSOps;
47245- SmallVector<int, 64> LHSMask, RHSMask, CondMask;
47246- if (createShuffleMaskFromVSELECT(CondMask, Cond) &&
47247- getTargetShuffleMask(LHS, true, LHSOps, LHSMask) &&
47248- getTargetShuffleMask(RHS, true, RHSOps, RHSMask)) {
47249- int NumElts = VT.getVectorNumElements();
47250- for (int i = 0; i != NumElts; ++i) {
47243+ SmallVector<int, 64> LHSMask, RHSMask, CondMask, ByteMask;
47244+ SDValue LHSShuf = peekThroughOneUseBitcasts(LHS);
47245+ SDValue RHSShuf = peekThroughOneUseBitcasts(RHS);
47246+ if (LHSShuf.getOpcode() == X86ISD::PSHUFB &&
47247+ RHSShuf.getOpcode() == X86ISD::PSHUFB &&
47248+ createShuffleMaskFromVSELECT(CondMask, Cond) &&
47249+ scaleShuffleMaskElts(VT.getSizeInBits() / 8, CondMask, ByteMask) &&
47250+ getTargetShuffleMask(LHSShuf, true, LHSOps, LHSMask) &&
47251+ getTargetShuffleMask(RHSShuf, true, RHSOps, RHSMask)) {
47252+ assert(ByteMask.size() == LHSMask.size() &&
47253+ ByteMask.size() == RHSMask.size() && "Shuffle mask mismatch");
47254+ for (auto [I, M] : enumerate(ByteMask)) {
4725147255 // getConstVector sets negative shuffle mask values as undef, so ensure
4725247256 // we hardcode SM_SentinelZero values to zero (0x80).
47253- if (CondMask[i] < NumElts ) {
47254- LHSMask[i ] = isUndefOrZero(LHSMask[i ]) ? 0x80 : LHSMask[i ];
47255- RHSMask[i ] = 0x80;
47257+ if (M < ByteMask.size() ) {
47258+ LHSMask[I ] = isUndefOrZero(LHSMask[I ]) ? 0x80 : LHSMask[I ];
47259+ RHSMask[I ] = 0x80;
4725647260 } else {
47257- LHSMask[i ] = 0x80;
47258- RHSMask[i ] = isUndefOrZero(RHSMask[i ]) ? 0x80 : RHSMask[i ];
47261+ LHSMask[I ] = 0x80;
47262+ RHSMask[I ] = isUndefOrZero(RHSMask[I ]) ? 0x80 : RHSMask[I ];
4725947263 }
4726047264 }
47261- LHS = DAG.getNode(X86ISD::PSHUFB, DL, VT, LHS.getOperand(0),
47262- getConstVector(LHSMask, SimpleVT, DAG, DL, true));
47263- RHS = DAG.getNode(X86ISD::PSHUFB, DL, VT, RHS.getOperand(0),
47264- getConstVector(RHSMask, SimpleVT, DAG, DL, true));
47265- return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
47265+ MVT ByteVT = LHSShuf.getSimpleValueType();
47266+ LHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, LHSOps[0],
47267+ getConstVector(LHSMask, ByteVT, DAG, DL, true));
47268+ RHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, RHSOps[0],
47269+ getConstVector(RHSMask, ByteVT, DAG, DL, true));
47270+ return DAG.getBitcast(VT, DAG.getNode(ISD::OR, DL, ByteVT, LHS, RHS));
4726647271 }
4726747272 }
4726847273
0 commit comments