@@ -54246,7 +54246,7 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
5424654246}
5424754247
5424854248// Try to form a MULHU or MULHS node by looking for
54249- // (trunc (srl (mul ext, ext), 16))
54249+ // (trunc (srl (mul ext, ext), >= 16))
5425054250// TODO: This is X86 specific because we want to be able to handle wide types
5425154251// before type legalization. But we can only do it if the vector will be
5425254252// legalized via widening/splitting. Type legalization can't handle promotion
@@ -54271,10 +54271,16 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
5427154271
5427254272 // First instruction should be a right shift by at least 16 of a multiply.
5427354273 SDValue LHS, RHS;
54274+ APInt ShiftAmt;
5427454275 if (!sd_match(Src,
54275- m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_SpecificInt(16 ))))
54276+ m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_ConstInt(ShiftAmt ))))
5427654277 return SDValue();
5427754278
54279+ if (ShiftAmt.ult(16) || ShiftAmt.uge(InVT.getScalarSizeInBits()))
54280+ return SDValue();
54281+
54282+ uint64_t AdditionalShift = ShiftAmt.getZExtValue() - 16;
54283+
5427854284 // Count leading sign/zero bits on both inputs - if there are enough then
5427954285 // truncation back to vXi16 will be cheap - either as a pack/shuffle
5428054286 // sequence or using AVX512 truncations. If the inputs are sext/zext then the
@@ -54312,15 +54318,19 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
5431254318 InVT.getSizeInBits() / 16);
5431354319 SDValue Res = DAG.getNode(ISD::MULHU, DL, BCVT, DAG.getBitcast(BCVT, LHS),
5431454320 DAG.getBitcast(BCVT, RHS));
54315- return DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
54321+ Res = DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
54322+ return DAG.getNode(ISD::SRL, DL, VT, Res,
54323+ DAG.getShiftAmountConstant(AdditionalShift, VT, DL));
5431654324 }
5431754325
5431854326 // Truncate back to source type.
5431954327 LHS = DAG.getNode(ISD::TRUNCATE, DL, VT, LHS);
5432054328 RHS = DAG.getNode(ISD::TRUNCATE, DL, VT, RHS);
5432154329
5432254330 unsigned Opc = IsSigned ? ISD::MULHS : ISD::MULHU;
54323- return DAG.getNode(Opc, DL, VT, LHS, RHS);
54331+ SDValue Res = DAG.getNode(Opc, DL, VT, LHS, RHS);
54332+ return DAG.getNode(ISD::SRL, DL, VT, Res,
54333+ DAG.getShiftAmountConstant(AdditionalShift, VT, DL));
5432454334}
5432554335
5432654336// Attempt to match PMADDUBSW, which multiplies corresponding unsigned bytes
0 commit comments