@@ -54246,7 +54246,7 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
54246
54246
}
54247
54247
54248
54248
// Try to form a MULHU or MULHS node by looking for
54249
- // (trunc (srl (mul ext, ext), 16))
54249
+ // (trunc (srl (mul ext, ext), >= 16))
54250
54250
// TODO: This is X86 specific because we want to be able to handle wide types
54251
54251
// before type legalization. But we can only do it if the vector will be
54252
54252
// legalized via widening/splitting. Type legalization can't handle promotion
@@ -54271,10 +54271,16 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
54271
54271
54272
54272
// First instruction should be a right shift by 16 of a multiply.
54273
54273
SDValue LHS, RHS;
54274
+ APInt ShiftAmt;
54274
54275
if (!sd_match(Src,
54275
- m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_SpecificInt(16 ))))
54276
+ m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_ConstInt(ShiftAmt ))))
54276
54277
return SDValue();
54277
54278
54279
+ if (ShiftAmt.ult(16) || ShiftAmt.uge(InVT.getScalarSizeInBits()))
54280
+ return SDValue();
54281
+
54282
+ uint64_t AdditionalShift = ShiftAmt.getZExtValue() - 16;
54283
+
54278
54284
// Count leading sign/zero bits on both inputs - if there are enough then
54279
54285
// truncation back to vXi16 will be cheap - either as a pack/shuffle
54280
54286
// sequence or using AVX512 truncations. If the inputs are sext/zext then the
@@ -54312,15 +54318,19 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
54312
54318
InVT.getSizeInBits() / 16);
54313
54319
SDValue Res = DAG.getNode(ISD::MULHU, DL, BCVT, DAG.getBitcast(BCVT, LHS),
54314
54320
DAG.getBitcast(BCVT, RHS));
54315
- return DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
54321
+ Res = DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
54322
+ return DAG.getNode(ISD::SRL, DL, VT, Res,
54323
+ DAG.getShiftAmountConstant(AdditionalShift, VT, DL));
54316
54324
}
54317
54325
54318
54326
// Truncate back to source type.
54319
54327
LHS = DAG.getNode(ISD::TRUNCATE, DL, VT, LHS);
54320
54328
RHS = DAG.getNode(ISD::TRUNCATE, DL, VT, RHS);
54321
54329
54322
54330
unsigned Opc = IsSigned ? ISD::MULHS : ISD::MULHU;
54323
- return DAG.getNode(Opc, DL, VT, LHS, RHS);
54331
+ SDValue Res = DAG.getNode(Opc, DL, VT, LHS, RHS);
54332
+ return DAG.getNode(ISD::SRL, DL, VT, Res,
54333
+ DAG.getShiftAmountConstant(AdditionalShift, VT, DL));
54324
54334
}
54325
54335
54326
54336
// Attempt to match PMADDUBSW, which multiplies corresponding unsigned bytes
0 commit comments