@@ -41368,6 +41368,22 @@ static SmallVector<int, 4> getPSHUFShuffleMask(SDValue N) {
4136841368 }
4136941369}
4137041370
41371+ /// Get the expanded blend mask from a BLENDI node.
41372+ /// For v16i16 nodes, this will splat the repeated i8 mask.
41373+ static APInt getBLENDIBlendMask(SDValue V) {
41374+ assert(V.getOpcode() == X86ISD::BLENDI && "Unknown blend shuffle");
41375+ unsigned NumElts = V.getSimpleValueType().getVectorNumElements();
41376+ APInt Mask = V.getConstantOperandAPInt(2);
41377+ if (Mask.getBitWidth() > NumElts)
41378+ Mask = Mask.trunc(NumElts);
41379+ if (NumElts == 16) {
41380+ assert(Mask.getBitWidth() == 8 && "Unexpected v16i16 blend mask width");
41381+ Mask = APInt::getSplat(16, Mask);
41382+ }
41383+ assert(Mask.getBitWidth() == NumElts && "Unexpected blend mask width");
41384+ return Mask;
41385+ }
41386+
4137141387/// Search for a combinable shuffle across a chain ending in pshufd.
4137241388///
4137341389/// We walk up the chain and look for a combinable shuffle, skipping over
@@ -42266,7 +42282,7 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL,
4226642282 unsigned SrcBits = SrcVT.getScalarSizeInBits();
4226742283 if ((EltBits % SrcBits) == 0 && SrcBits >= 32) {
4226842284 unsigned NewSize = SrcVT.getVectorNumElements();
42269- APInt BlendMask = N.getConstantOperandAPInt(2).zextOrTrunc(NumElts );
42285+ APInt BlendMask = getBLENDIBlendMask(N );
4227042286 APInt NewBlendMask = APIntOps::ScaleBitMask(BlendMask, NewSize);
4227142287 return DAG.getBitcast(
4227242288 VT, DAG.getNode(X86ISD::BLENDI, DL, SrcVT, N0.getOperand(0),
@@ -58488,16 +58504,11 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
5848858504 break;
5848958505 case X86ISD::BLENDI:
5849058506 if (NumOps == 2 && VT.is512BitVector() && Subtarget.useBWIRegs()) {
58491- uint64_t Mask0 = Ops[0].getConstantOperandVal(2);
58492- uint64_t Mask1 = Ops[1].getConstantOperandVal(2);
58493- // MVT::v16i16 has repeated blend mask.
58494- if (Op0.getSimpleValueType() == MVT::v16i16) {
58495- Mask0 = (Mask0 << 8) | Mask0;
58496- Mask1 = (Mask1 << 8) | Mask1;
58497- }
58498- uint64_t Mask = (Mask1 << (VT.getVectorNumElements() / 2)) | Mask0;
58499- MVT MaskSVT = MVT::getIntegerVT(VT.getVectorNumElements());
58500- MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements());
58507+ unsigned NumElts = VT.getVectorNumElements();
58508+ APInt Mask = getBLENDIBlendMask(Ops[0]).zext(NumElts);
58509+ Mask.insertBits(getBLENDIBlendMask(Ops[1]), NumElts / 2);
58510+ MVT MaskSVT = MVT::getIntegerVT(NumElts);
58511+ MVT MaskVT = MVT::getVectorVT(MVT::i1, NumElts);
5850158512 SDValue Sel =
5850258513 DAG.getBitcast(MaskVT, DAG.getConstant(Mask, DL, MaskSVT));
5850358514 return DAG.getSelect(DL, VT, Sel, ConcatSubOperand(VT, Ops, 1),
0 commit comments