@@ -46047,23 +46047,22 @@ static SDValue createVPDPBUSD(SelectionDAG &DAG, SDValue LHS, SDValue RHS,
4604746047 DpBuilder, false);
4604846048}
4604946049
46050- // Given two zexts of <k x i8> to <k x i32>, create a PSADBW of the inputs
46051- // to these zexts.
46052- static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
46053- const SDValue &Zext1, const SDLoc &DL,
46054- const X86Subtarget &Subtarget) {
46050+ // Create a PSADBW given two sources representable as zexts of vXi8.
46051+ static SDValue createPSADBW(SelectionDAG &DAG, SDValue N0, SDValue N1,
46052+ const SDLoc &DL, const X86Subtarget &Subtarget) {
4605546053 // Find the appropriate width for the PSADBW.
46056- EVT InVT = Zext0.getOperand(0).getValueType();
46057- unsigned RegSize = std::max(128u, (unsigned)InVT.getSizeInBits());
46058-
46059- // "Zero-extend" the i8 vectors. This is not a per-element zext, rather we
46060- // fill in the missing vector elements with 0.
46061- unsigned NumConcat = RegSize / InVT.getSizeInBits();
46062- SmallVector<SDValue, 16> Ops(NumConcat, DAG.getConstant(0, DL, InVT));
46063- Ops[0] = Zext0.getOperand(0);
46054+ EVT DstVT = N0.getValueType();
46055+ EVT SrcVT = EVT::getVectorVT(*DAG.getContext(), MVT::i8,
46056+ DstVT.getVectorElementCount());
46057+ unsigned RegSize = std::max(128u, (unsigned)SrcVT.getSizeInBits());
46058+
46059+ // Widen the vXi8 vectors, padding with zero vector elements.
46060+ unsigned NumConcat = RegSize / SrcVT.getSizeInBits();
46061+ SmallVector<SDValue, 16> Ops(NumConcat, DAG.getConstant(0, DL, SrcVT));
46062+ Ops[0] = DAG.getZExtOrTrunc(N0, DL, SrcVT);
4606446063 MVT ExtendedVT = MVT::getVectorVT(MVT::i8, RegSize / 8);
4606546064 SDValue SadOp0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
46066- Ops[0] = Zext1.getOperand(0 );
46065+ Ops[0] = DAG.getZExtOrTrunc(N1, DL, SrcVT );
4606746066 SDValue SadOp1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
4606846067
4606946068 // Actually build the SAD, split as 128/256/512 bits for SSE/AVX2/AVX512BW.
@@ -46073,7 +46072,7 @@ static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
4607346072 return DAG.getNode(X86ISD::PSADBW, DL, VT, Ops);
4607446073 };
4607546074 MVT SadVT = MVT::getVectorVT(MVT::i64, RegSize / 64);
46076- return SplitOpsAndApply(DAG, Subtarget, DL, SadVT, { SadOp0, SadOp1 },
46075+ return SplitOpsAndApply(DAG, Subtarget, DL, SadVT, {SadOp0, SadOp1},
4607746076 PSADBWBuilder);
4607846077}
4607946078
@@ -46372,9 +46371,8 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
4637246371 return SDValue();
4637346372
4637446373 EVT ExtractVT = Extract->getValueType(0);
46375- // Verify the type we're extracting is either i32 or i64.
46376- // FIXME: Could support other types, but this is what we have coverage for.
46377- if (ExtractVT != MVT::i32 && ExtractVT != MVT::i64)
46374+ if (ExtractVT != MVT::i8 && ExtractVT != MVT::i16 && ExtractVT != MVT::i32 &&
46375+ ExtractVT != MVT::i64)
4637846376 return SDValue();
4637946377
4638046378 EVT VT = Extract->getOperand(0).getValueType();
@@ -46399,20 +46397,27 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
4639946397 Root.getOpcode() == ISD::ANY_EXTEND)
4640046398 Root = Root.getOperand(0);
4640146399
46402- // Check whether we have an abdu pattern.
46403- // TODO: Add handling for ISD::ABDU.
46404- SDValue Zext0, Zext1 ;
46400+ // Check whether we have an vXi8 abdu pattern.
46401+ // TODO: Just match ISD::ABDU once the DAG is topological sorted .
46402+ SDValue Src0, Src1 ;
4640546403 if (!sd_match(
4640646404 Root,
46407- m_Abs(m_Sub(m_AllOf(m_Value(Zext0),
46408- m_ZExt(m_SpecificVectorElementVT(MVT::i8))),
46409- m_AllOf(m_Value(Zext1),
46410- m_ZExt(m_SpecificVectorElementVT(MVT::i8)))))))
46405+ m_AnyOf(
46406+ m_SpecificVectorElementVT(
46407+ MVT::i8, m_c_BinOp(ISD::ABDU, m_Value(Src0), m_Value(Src1))),
46408+ m_SpecificVectorElementVT(
46409+ MVT::i8, m_Sub(m_UMax(m_Value(Src0), m_Value(Src1)),
46410+ m_UMin(m_Deferred(Src0), m_Deferred(Src1)))),
46411+ m_Abs(
46412+ m_Sub(m_AllOf(m_Value(Src0),
46413+ m_ZExt(m_SpecificVectorElementVT(MVT::i8))),
46414+ m_AllOf(m_Value(Src1),
46415+ m_ZExt(m_SpecificVectorElementVT(MVT::i8))))))))
4641146416 return SDValue();
4641246417
4641346418 // Create the SAD instruction.
4641446419 SDLoc DL(Extract);
46415- SDValue SAD = createPSADBW(DAG, Zext0, Zext1 , DL, Subtarget);
46420+ SDValue SAD = createPSADBW(DAG, Src0, Src1 , DL, Subtarget);
4641646421
4641746422 // If the original vector was wider than 8 elements, sum over the results
4641846423 // in the SAD vector.
0 commit comments