@@ -9782,20 +9782,23 @@ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp,
97829782 return false;
97839783
97849784 EVT VT = Op.getValueType();
9785+ EVT ExpectedVT = ExpectedOp.getValueType();
9786+
9787+ // Sources must be vectors and match the mask's element count.
9788+ if (!VT.isVector() || !ExpectedVT.isVector() ||
9789+ (int)VT.getVectorNumElements() != MaskSize ||
9790+ (int)ExpectedVT.getVectorNumElements() != MaskSize)
9791+ return false;
9792+
97859793 switch (Op.getOpcode()) {
97869794 case ISD::BUILD_VECTOR:
97879795 // If the values are build vectors, we can look through them to find
97889796 // equivalent inputs that make the shuffles equivalent.
9789- // TODO: Handle MaskSize != Op.getNumOperands()?
9790- if (MaskSize == (int)Op.getNumOperands() &&
9791- MaskSize == (int)ExpectedOp.getNumOperands())
9792- return Op.getOperand(Idx) == ExpectedOp.getOperand(ExpectedIdx);
9793- break;
9797+ return Op.getOperand(Idx) == ExpectedOp.getOperand(ExpectedIdx);
97949798 case ISD::BITCAST: {
97959799 SDValue Src = peekThroughBitcasts(Op);
97969800 EVT SrcVT = Src.getValueType();
9797- if (Op == ExpectedOp && SrcVT.isVector() &&
9798- (int)VT.getVectorNumElements() == MaskSize) {
9801+ if (Op == ExpectedOp && SrcVT.isVector()) {
97999802 if ((SrcVT.getScalarSizeInBits() % VT.getScalarSizeInBits()) == 0) {
98009803 unsigned Scale = SrcVT.getScalarSizeInBits() / VT.getScalarSizeInBits();
98019804 return (Idx % Scale) == (ExpectedIdx % Scale) &&
@@ -9816,23 +9819,21 @@ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp,
98169819 }
98179820 case ISD::VECTOR_SHUFFLE: {
98189821 auto *SVN = cast<ShuffleVectorSDNode>(Op);
9819- return Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize &&
9822+ return Op == ExpectedOp &&
98209823 SVN->getMaskElt(Idx) == SVN->getMaskElt(ExpectedIdx);
98219824 }
98229825 case X86ISD::VBROADCAST:
98239826 case X86ISD::VBROADCAST_LOAD:
9824- // TODO: Handle MaskSize != VT.getVectorNumElements()?
9825- return (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize);
9827+ return Op == ExpectedOp;
98269828 case X86ISD::SUBV_BROADCAST_LOAD:
9827- // TODO: Handle MaskSize != VT.getVectorNumElements()?
9828- if (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize) {
9829+ if (Op == ExpectedOp) {
98299830 auto *MemOp = cast<MemSDNode>(Op);
98309831 unsigned NumMemElts = MemOp->getMemoryVT().getVectorNumElements();
98319832 return (Idx % NumMemElts) == (ExpectedIdx % NumMemElts);
98329833 }
98339834 break;
98349835 case X86ISD::VPERMI: {
9835- if (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize ) {
9836+ if (Op == ExpectedOp) {
98369837 SmallVector<int, 8> Mask;
98379838 DecodeVPERMMask(MaskSize, Op.getConstantOperandVal(1), Mask);
98389839 SDValue Src = Op.getOperand(0);
@@ -9849,20 +9850,16 @@ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp,
98499850 case X86ISD::PACKSS:
98509851 case X86ISD::PACKUS:
98519852 // HOP(X,X) can refer to the elt from the lower/upper half of a lane.
9852- // TODO: Handle MaskSize != NumElts?
98539853 // TODO: Handle HOP(X,Y) vs HOP(Y,X) equivalence cases.
98549854 if (Op == ExpectedOp && Op.getOperand(0) == Op.getOperand(1)) {
98559855 int NumElts = VT.getVectorNumElements();
9856- if (MaskSize == NumElts) {
9857- int NumLanes = VT.getSizeInBits() / 128;
9858- int NumEltsPerLane = NumElts / NumLanes;
9859- int NumHalfEltsPerLane = NumEltsPerLane / 2;
9860- bool SameLane =
9861- (Idx / NumEltsPerLane) == (ExpectedIdx / NumEltsPerLane);
9862- bool SameElt =
9863- (Idx % NumHalfEltsPerLane) == (ExpectedIdx % NumHalfEltsPerLane);
9864- return SameLane && SameElt;
9865- }
9856+ int NumLanes = VT.getSizeInBits() / 128;
9857+ int NumEltsPerLane = NumElts / NumLanes;
9858+ int NumHalfEltsPerLane = NumEltsPerLane / 2;
9859+ bool SameLane = (Idx / NumEltsPerLane) == (ExpectedIdx / NumEltsPerLane);
9860+ bool SameElt =
9861+ (Idx % NumHalfEltsPerLane) == (ExpectedIdx % NumHalfEltsPerLane);
9862+ return SameLane && SameElt;
98669863 }
98679864 break;
98689865 }
0 commit comments