@@ -54459,6 +54459,7 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
5445954459static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,
5446054460 const X86Subtarget &Subtarget,
5446154461 const SDLoc &DL) {
54462+ using namespace SDPatternMatch;
5446254463 if (!VT.isVector() || !Subtarget.hasSSSE3())
5446354464 return SDValue();
5446454465
@@ -54468,42 +54469,19 @@ static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,
5446854469 return SDValue();
5446954470
5447054471 SDValue SSatVal = detectSSatPattern(In, VT);
54471- if (!SSatVal || SSatVal.getOpcode() != ISD::ADD)
54472- return SDValue();
54473-
54474- // Ok this is a signed saturation of an ADD. See if this ADD is adding pairs
54475- // of multiplies from even/odd elements.
54476- SDValue N0 = SSatVal.getOperand(0);
54477- SDValue N1 = SSatVal.getOperand(1);
54478-
54479- if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
54472+ if (!SSatVal)
5448054473 return SDValue();
5448154474
54482- SDValue N00 = N0.getOperand(0);
54483- SDValue N01 = N0.getOperand(1);
54484- SDValue N10 = N1.getOperand(0);
54485- SDValue N11 = N1.getOperand(1);
54486-
54475+ // See if this is a signed saturation of an ADD, adding pairs of multiplies
54476+ // from even/odd elements, from zero_extend/sign_extend operands.
54477+ //
5448754478 // TODO: Handle constant vectors and use knownbits/computenumsignbits?
54488- // Canonicalize zero_extend to LHS.
54489- if (N01.getOpcode() == ISD::ZERO_EXTEND)
54490- std::swap(N00, N01);
54491- if (N11.getOpcode() == ISD::ZERO_EXTEND)
54492- std::swap(N10, N11);
54493-
54494- // Ensure we have a zero_extend and a sign_extend.
54495- if (N00.getOpcode() != ISD::ZERO_EXTEND ||
54496- N01.getOpcode() != ISD::SIGN_EXTEND ||
54497- N10.getOpcode() != ISD::ZERO_EXTEND ||
54498- N11.getOpcode() != ISD::SIGN_EXTEND)
54479+ SDValue N00, N01, N10, N11;
54480+ if (!sd_match(SSatVal,
54481+ m_Add(m_Mul(m_ZExt(m_Value(N00)), m_SExt(m_Value(N01))),
54482+ m_Mul(m_ZExt(m_Value(N10)), m_SExt(m_Value(N11))))))
5449954483 return SDValue();
5450054484
54501- // Peek through the extends.
54502- N00 = N00.getOperand(0);
54503- N01 = N01.getOperand(0);
54504- N10 = N10.getOperand(0);
54505- N11 = N11.getOperand(0);
54506-
5450754485 // Ensure the extend is from vXi8.
5450854486 if (N00.getValueType().getVectorElementType() != MVT::i8 ||
5450954487 N01.getValueType().getVectorElementType() != MVT::i8 ||
0 commit comments