@@ -54021,7 +54021,7 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
5402154021}
5402254022
5402354023// Try to form a MULHU or MULHS node by looking for
54024- // (trunc (srl (mul ext, ext), 16))
54024+ // (trunc (srl (mul ext, ext), >= 16))
5402554025// TODO: This is X86 specific because we want to be able to handle wide types
5402654026// before type legalization. But we can only do it if the vector will be
5402754027// legalized via widening/splitting. Type legalization can't handle promotion
@@ -54046,10 +54046,16 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
5404654046
5404754047 // First instruction should be a right shift by 16 of a multiply.
5404854048 SDValue LHS, RHS;
54049+ APInt ShiftAmt;
5404954050 if (!sd_match(Src,
54050- m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_SpecificInt(16))))
54051+ m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_ConstInt(ShiftAmt))))
54052+ return SDValue();
54053+
54054+ if (ShiftAmt.ult(16))
5405154055 return SDValue();
5405254056
54057+ APInt AdditionalShift = (ShiftAmt - 16).trunc(16);
54058+
5405354059 // Count leading sign/zero bits on both inputs - if there are enough then
5405454060 // truncation back to vXi16 will be cheap - either as a pack/shuffle
5405554061 // sequence or using AVX512 truncations. If the inputs are sext/zext then the
@@ -54087,15 +54093,19 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
5408754093 InVT.getSizeInBits() / 16);
5408854094 SDValue Res = DAG.getNode(ISD::MULHU, DL, BCVT, DAG.getBitcast(BCVT, LHS),
5408954095 DAG.getBitcast(BCVT, RHS));
54090- return DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
54096+ Res = DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
54097+ return DAG.getNode(ISD::SRL, DL, VT, Res,
54098+ DAG.getConstant(AdditionalShift, DL, VT));
5409154099 }
5409254100
5409354101 // Truncate back to source type.
5409454102 LHS = DAG.getNode(ISD::TRUNCATE, DL, VT, LHS);
5409554103 RHS = DAG.getNode(ISD::TRUNCATE, DL, VT, RHS);
5409654104
5409754105 unsigned Opc = IsSigned ? ISD::MULHS : ISD::MULHU;
54098- return DAG.getNode(Opc, DL, VT, LHS, RHS);
54106+ SDValue Res = DAG.getNode(Opc, DL, VT, LHS, RHS);
54107+ return DAG.getNode(ISD::SRL, DL, VT, Res,
54108+ DAG.getConstant(AdditionalShift, DL, VT));
5409954109}
5410054110
5410154111// Attempt to match PMADDUBSW, which multiplies corresponding unsigned bytes
0 commit comments