diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 2541182de1208..d5837ab938d4e 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -47690,12 +47690,47 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, return V; if (N->getOpcode() == ISD::VSELECT || N->getOpcode() == X86ISD::BLENDV) { - SmallVector Mask; - if (createShuffleMaskFromVSELECT(Mask, Cond, + SmallVector CondMask; + if (createShuffleMaskFromVSELECT(CondMask, Cond, N->getOpcode() == X86ISD::BLENDV)) { // Convert vselects with constant condition into shuffles. if (DCI.isBeforeLegalizeOps()) - return DAG.getVectorShuffle(VT, DL, LHS, RHS, Mask); + return DAG.getVectorShuffle(VT, DL, LHS, RHS, CondMask); + + // fold vselect(cond, pshufb(x), pshufb(y)) -> or (pshufb(x), pshufb(y)) + // by forcing the unselected elements to zero. + // TODO: Can we handle more shuffles with this? + if (LHS.hasOneUse() && RHS.hasOneUse()) { + SmallVector LHSOps, RHSOps; + SmallVector LHSMask, RHSMask, ByteMask; + SDValue LHSShuf = peekThroughOneUseBitcasts(LHS); + SDValue RHSShuf = peekThroughOneUseBitcasts(RHS); + if (LHSShuf.getOpcode() == X86ISD::PSHUFB && + RHSShuf.getOpcode() == X86ISD::PSHUFB && + scaleShuffleMaskElts(VT.getSizeInBits() / 8, CondMask, ByteMask) && + getTargetShuffleMask(LHSShuf, true, LHSOps, LHSMask) && + getTargetShuffleMask(RHSShuf, true, RHSOps, RHSMask)) { + assert(ByteMask.size() == LHSMask.size() && + ByteMask.size() == RHSMask.size() && "Shuffle mask mismatch"); + for (auto [I, M] : enumerate(ByteMask)) { + // getConstVector sets negative shuffle mask values as undef, so + // ensure we hardcode SM_SentinelZero values to zero (0x80). + if (M < (int)ByteMask.size()) { + LHSMask[I] = isUndefOrZero(LHSMask[I]) ? 0x80 : LHSMask[I]; + RHSMask[I] = 0x80; + } else { + LHSMask[I] = 0x80; + RHSMask[I] = isUndefOrZero(RHSMask[I]) ? 0x80 : RHSMask[I]; + } + } + MVT ByteVT = LHSShuf.getSimpleValueType(); + LHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, LHSOps[0], + getConstVector(LHSMask, ByteVT, DAG, DL, true)); + RHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, RHSOps[0], + getConstVector(RHSMask, ByteVT, DAG, DL, true)); + return DAG.getBitcast(VT, DAG.getNode(ISD::OR, DL, ByteVT, LHS, RHS)); + } + } // Attempt to combine as shuffle. SDValue Op(N, 0); @@ -47704,43 +47739,6 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, } } - // fold vselect(cond, pshufb(x), pshufb(y)) -> or (pshufb(x), pshufb(y)) - // by forcing the unselected elements to zero. - // TODO: Can we handle more shuffles with this? - if (N->getOpcode() == ISD::VSELECT && CondVT.isVector() && LHS.hasOneUse() && - RHS.hasOneUse()) { - SmallVector LHSOps, RHSOps; - SmallVector LHSMask, RHSMask, CondMask, ByteMask; - SDValue LHSShuf = peekThroughOneUseBitcasts(LHS); - SDValue RHSShuf = peekThroughOneUseBitcasts(RHS); - if (LHSShuf.getOpcode() == X86ISD::PSHUFB && - RHSShuf.getOpcode() == X86ISD::PSHUFB && - createShuffleMaskFromVSELECT(CondMask, Cond) && - scaleShuffleMaskElts(VT.getSizeInBits() / 8, CondMask, ByteMask) && - getTargetShuffleMask(LHSShuf, true, LHSOps, LHSMask) && - getTargetShuffleMask(RHSShuf, true, RHSOps, RHSMask)) { - assert(ByteMask.size() == LHSMask.size() && - ByteMask.size() == RHSMask.size() && "Shuffle mask mismatch"); - for (auto [I, M] : enumerate(ByteMask)) { - // getConstVector sets negative shuffle mask values as undef, so ensure - // we hardcode SM_SentinelZero values to zero (0x80). - if (M < (int)ByteMask.size()) { - LHSMask[I] = isUndefOrZero(LHSMask[I]) ? 0x80 : LHSMask[I]; - RHSMask[I] = 0x80; - } else { - LHSMask[I] = 0x80; - RHSMask[I] = isUndefOrZero(RHSMask[I]) ? 0x80 : RHSMask[I]; - } - } - MVT ByteVT = LHSShuf.getSimpleValueType(); - LHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, LHSOps[0], - getConstVector(LHSMask, ByteVT, DAG, DL, true)); - RHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, RHSOps[0], - getConstVector(RHSMask, ByteVT, DAG, DL, true)); - return DAG.getBitcast(VT, DAG.getNode(ISD::OR, DL, ByteVT, LHS, RHS)); - } - } - // If we have SSE[12] support, try to form min/max nodes. SSE min/max // instructions match the semantics of the common C idiom x