Skip to content

Commit df60a87

Browse files
[X86] combinePMULH - combine mulhu + srl
Fixes #132166
1 parent 2909c42 commit df60a87

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54021,7 +54021,7 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
5402154021
}
5402254022

5402354023
// Try to form a MULHU or MULHS node by looking for
54024-
// (trunc (srl (mul ext, ext), 16))
54024+
// (trunc (srl (mul ext, ext), >= 16))
5402554025
// TODO: This is X86 specific because we want to be able to handle wide types
5402654026
// before type legalization. But we can only do it if the vector will be
5402754027
// legalized via widening/splitting. Type legalization can't handle promotion
@@ -54046,10 +54046,16 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
5404654046

5404754047
// First instruction should be a right shift by 16 of a multiply.
5404854048
SDValue LHS, RHS;
54049+
APInt ShiftAmt;
5404954050
if (!sd_match(Src,
54050-
m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_SpecificInt(16))))
54051+
m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_ConstInt(ShiftAmt))))
54052+
return SDValue();
54053+
54054+
if (ShiftAmt.ult(16))
5405154055
return SDValue();
5405254056

54057+
APInt AdditionalShift = (ShiftAmt - 16).trunc(16);
54058+
5405354059
// Count leading sign/zero bits on both inputs - if there are enough then
5405454060
// truncation back to vXi16 will be cheap - either as a pack/shuffle
5405554061
// sequence or using AVX512 truncations. If the inputs are sext/zext then the
@@ -54087,15 +54093,19 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
5408754093
InVT.getSizeInBits() / 16);
5408854094
SDValue Res = DAG.getNode(ISD::MULHU, DL, BCVT, DAG.getBitcast(BCVT, LHS),
5408954095
DAG.getBitcast(BCVT, RHS));
54090-
return DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
54096+
Res = DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
54097+
return DAG.getNode(ISD::SRL, DL, VT, Res,
54098+
DAG.getConstant(AdditionalShift, DL, VT));
5409154099
}
5409254100

5409354101
// Truncate back to source type.
5409454102
LHS = DAG.getNode(ISD::TRUNCATE, DL, VT, LHS);
5409554103
RHS = DAG.getNode(ISD::TRUNCATE, DL, VT, RHS);
5409654104

5409754105
unsigned Opc = IsSigned ? ISD::MULHS : ISD::MULHU;
54098-
return DAG.getNode(Opc, DL, VT, LHS, RHS);
54106+
SDValue Res = DAG.getNode(Opc, DL, VT, LHS, RHS);
54107+
return DAG.getNode(ISD::SRL, DL, VT, Res,
54108+
DAG.getConstant(AdditionalShift, DL, VT));
5409954109
}
5410054110

5410154111
// Attempt to match PMADDUBSW, which multiplies corresponding unsigned bytes

llvm/test/CodeGen/X86/pmulh.ll

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2166,3 +2166,23 @@ define <8 x i16> @sse2_pmulhu_w_const(<8 x i16> %a0, <8 x i16> %a1) {
21662166
}
21672167
declare <8 x i16> @llvm.x86.sse2.pmulhu.w(<8 x i16>, <8 x i16>)
21682168

2169+
define <8 x i16> @mul_and_shift17(<8 x i16> %a, <8 x i16> %b) {
2170+
; SSE-LABEL: mul_and_shift17:
2171+
; SSE: # %bb.0:
2172+
; SSE-NEXT: pmulhuw %xmm1, %xmm0
2173+
; SSE-NEXT: psrlw $1, %xmm0
2174+
; SSE-NEXT: retq
2175+
;
2176+
; AVX-LABEL: mul_and_shift17:
2177+
; AVX: # %bb.0:
2178+
; AVX-NEXT: vpmulhuw %xmm1, %xmm0, %xmm0
2179+
; AVX-NEXT: vpsrlw $1, %xmm0, %xmm0
2180+
; AVX-NEXT: retq
2181+
%a.ext = zext <8 x i16> %a to <8 x i32>
2182+
%b.ext = zext <8 x i16> %b to <8 x i32>
2183+
%mul = mul <8 x i32> %a.ext, %b.ext
2184+
%shift = lshr <8 x i32> %mul, splat(i32 17)
2185+
%trunc = trunc <8 x i32> %shift to <8 x i16>
2186+
ret <8 x i16> %trunc
2187+
}
2188+

0 commit comments

Comments
 (0)