Skip to content

Commit 6d2e00f

Browse files
abhishek-kaushik22krishna2803
authored andcommitted
[X86] combinePMULH - combine mulhu + srl (llvm#132548)
Fixes llvm#132166
1 parent 26f269f commit 6d2e00f

File tree

2 files changed

+719
-4
lines changed

2 files changed

+719
-4
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54246,7 +54246,7 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
5424654246
}
5424754247

5424854248
// Try to form a MULHU or MULHS node by looking for
54249-
// (trunc (srl (mul ext, ext), 16))
54249+
// (trunc (srl (mul ext, ext), >= 16))
5425054250
// TODO: This is X86 specific because we want to be able to handle wide types
5425154251
// before type legalization. But we can only do it if the vector will be
5425254252
// legalized via widening/splitting. Type legalization can't handle promotion
@@ -54271,10 +54271,16 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
5427154271

5427254272
// First instruction should be a right shift by 16 of a multiply.
5427354273
SDValue LHS, RHS;
54274+
APInt ShiftAmt;
5427454275
if (!sd_match(Src,
54275-
m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_SpecificInt(16))))
54276+
m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_ConstInt(ShiftAmt))))
5427654277
return SDValue();
5427754278

54279+
if (ShiftAmt.ult(16) || ShiftAmt.uge(InVT.getScalarSizeInBits()))
54280+
return SDValue();
54281+
54282+
uint64_t AdditionalShift = ShiftAmt.getZExtValue() - 16;
54283+
5427854284
// Count leading sign/zero bits on both inputs - if there are enough then
5427954285
// truncation back to vXi16 will be cheap - either as a pack/shuffle
5428054286
// sequence or using AVX512 truncations. If the inputs are sext/zext then the
@@ -54312,15 +54318,19 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
5431254318
InVT.getSizeInBits() / 16);
5431354319
SDValue Res = DAG.getNode(ISD::MULHU, DL, BCVT, DAG.getBitcast(BCVT, LHS),
5431454320
DAG.getBitcast(BCVT, RHS));
54315-
return DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
54321+
Res = DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
54322+
return DAG.getNode(ISD::SRL, DL, VT, Res,
54323+
DAG.getShiftAmountConstant(AdditionalShift, VT, DL));
5431654324
}
5431754325

5431854326
// Truncate back to source type.
5431954327
LHS = DAG.getNode(ISD::TRUNCATE, DL, VT, LHS);
5432054328
RHS = DAG.getNode(ISD::TRUNCATE, DL, VT, RHS);
5432154329

5432254330
unsigned Opc = IsSigned ? ISD::MULHS : ISD::MULHU;
54323-
return DAG.getNode(Opc, DL, VT, LHS, RHS);
54331+
SDValue Res = DAG.getNode(Opc, DL, VT, LHS, RHS);
54332+
return DAG.getNode(ISD::SRL, DL, VT, Res,
54333+
DAG.getShiftAmountConstant(AdditionalShift, VT, DL));
5432454334
}
5432554335

5432654336
// Attempt to match PMADDUBSW, which multiplies corresponding unsigned bytes

0 commit comments

Comments
 (0)