Skip to content

Commit 779e946

Browse files
RKSimonDebadri Basak
authored andcommitted
[X86] detectPMADDUBSW - use SDPatternMatch to handle some of the commutative matching for the zext/sext pairs. NFC. (llvm#165861)
1 parent 8abd5c0 commit 779e946

File tree

1 file changed

+9
-31
lines changed

1 file changed

+9
-31
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -54459,6 +54459,7 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
5445954459
static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,
5446054460
const X86Subtarget &Subtarget,
5446154461
const SDLoc &DL) {
54462+
using namespace SDPatternMatch;
5446254463
if (!VT.isVector() || !Subtarget.hasSSSE3())
5446354464
return SDValue();
5446454465

@@ -54468,42 +54469,19 @@ static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,
5446854469
return SDValue();
5446954470

5447054471
SDValue SSatVal = detectSSatPattern(In, VT);
54471-
if (!SSatVal || SSatVal.getOpcode() != ISD::ADD)
54472-
return SDValue();
54473-
54474-
// Ok this is a signed saturation of an ADD. See if this ADD is adding pairs
54475-
// of multiplies from even/odd elements.
54476-
SDValue N0 = SSatVal.getOperand(0);
54477-
SDValue N1 = SSatVal.getOperand(1);
54478-
54479-
if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
54472+
if (!SSatVal)
5448054473
return SDValue();
5448154474

54482-
SDValue N00 = N0.getOperand(0);
54483-
SDValue N01 = N0.getOperand(1);
54484-
SDValue N10 = N1.getOperand(0);
54485-
SDValue N11 = N1.getOperand(1);
54486-
54475+
// See if this is a signed saturation of an ADD, adding pairs of multiplies
54476+
// from even/odd elements, from zero_extend/sign_extend operands.
54477+
//
5448754478
// TODO: Handle constant vectors and use knownbits/computenumsignbits?
54488-
// Canonicalize zero_extend to LHS.
54489-
if (N01.getOpcode() == ISD::ZERO_EXTEND)
54490-
std::swap(N00, N01);
54491-
if (N11.getOpcode() == ISD::ZERO_EXTEND)
54492-
std::swap(N10, N11);
54493-
54494-
// Ensure we have a zero_extend and a sign_extend.
54495-
if (N00.getOpcode() != ISD::ZERO_EXTEND ||
54496-
N01.getOpcode() != ISD::SIGN_EXTEND ||
54497-
N10.getOpcode() != ISD::ZERO_EXTEND ||
54498-
N11.getOpcode() != ISD::SIGN_EXTEND)
54479+
SDValue N00, N01, N10, N11;
54480+
if (!sd_match(SSatVal,
54481+
m_Add(m_Mul(m_ZExt(m_Value(N00)), m_SExt(m_Value(N01))),
54482+
m_Mul(m_ZExt(m_Value(N10)), m_SExt(m_Value(N11))))))
5449954483
return SDValue();
5450054484

54501-
// Peek through the extends.
54502-
N00 = N00.getOperand(0);
54503-
N01 = N01.getOperand(0);
54504-
N10 = N10.getOperand(0);
54505-
N11 = N11.getOperand(0);
54506-
5450754485
// Ensure the extend is from vXi8.
5450854486
if (N00.getValueType().getVectorElementType() != MVT::i8 ||
5450954487
N01.getValueType().getVectorElementType() != MVT::i8 ||

0 commit comments

Comments
 (0)