@abhishek-kaushik22
Contributor

Fixes #132166

@llvmbot
Member

llvmbot commented Mar 22, 2025

@llvm/pr-subscribers-backend-x86

Author: Abhishek Kaushik (abhishek-kaushik22)

Changes

Fixes #132166


Full diff: https://github.com/llvm/llvm-project/pull/132548.diff

2 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+14-4)
  • (modified) llvm/test/CodeGen/X86/pmulh.ll (+20)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 02398923ebc90..ec0af8d53b76e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -54021,7 +54021,7 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
 }
 
 // Try to form a MULHU or MULHS node by looking for
-// (trunc (srl (mul ext, ext), 16))
+// (trunc (srl (mul ext, ext), >= 16))
 // TODO: This is X86 specific because we want to be able to handle wide types
 // before type legalization. But we can only do it if the vector will be
 // legalized via widening/splitting. Type legalization can't handle promotion
@@ -54046,10 +54046,16 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
 
   // First instruction should be a right shift by 16 of a multiply.
   SDValue LHS, RHS;
+  APInt ShiftAmt;
   if (!sd_match(Src,
-                m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_SpecificInt(16))))
+                m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_ConstInt(ShiftAmt))))
+    return SDValue();
+
+  if (ShiftAmt.ult(16))
     return SDValue();
 
+  APInt AdditionalShift = (ShiftAmt - 16).trunc(16);
+
   // Count leading sign/zero bits on both inputs - if there are enough then
   // truncation back to vXi16 will be cheap - either as a pack/shuffle
   // sequence or using AVX512 truncations. If the inputs are sext/zext then the
@@ -54087,7 +54093,9 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
                                 InVT.getSizeInBits() / 16);
     SDValue Res = DAG.getNode(ISD::MULHU, DL, BCVT, DAG.getBitcast(BCVT, LHS),
                               DAG.getBitcast(BCVT, RHS));
-    return DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
+    Res = DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
+    return DAG.getNode(ISD::SRL, DL, VT, Res,
+                       DAG.getConstant(AdditionalShift, DL, VT));
   }
 
   // Truncate back to source type.
@@ -54095,7 +54103,9 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
   RHS = DAG.getNode(ISD::TRUNCATE, DL, VT, RHS);
 
   unsigned Opc = IsSigned ? ISD::MULHS : ISD::MULHU;
-  return DAG.getNode(Opc, DL, VT, LHS, RHS);
+  SDValue Res = DAG.getNode(Opc, DL, VT, LHS, RHS);
+  return DAG.getNode(ISD::SRL, DL, VT, Res,
+                     DAG.getConstant(AdditionalShift, DL, VT));
 }
 
 // Attempt to match PMADDUBSW, which multiplies corresponding unsigned bytes
diff --git a/llvm/test/CodeGen/X86/pmulh.ll b/llvm/test/CodeGen/X86/pmulh.ll
index 300da68d9a3b3..8ecc3c1575367 100644
--- a/llvm/test/CodeGen/X86/pmulh.ll
+++ b/llvm/test/CodeGen/X86/pmulh.ll
@@ -2166,3 +2166,23 @@ define <8 x i16> @sse2_pmulhu_w_const(<8 x i16> %a0, <8 x i16> %a1) {
 }
 declare <8 x i16> @llvm.x86.sse2.pmulhu.w(<8 x i16>, <8 x i16>)
 
+define <8 x i16> @mul_and_shift17(<8 x i16> %a, <8 x i16> %b) {
+; SSE-LABEL: mul_and_shift17:
+; SSE:       # %bb.0:
+; SSE-NEXT:    pmulhuw %xmm1, %xmm0
+; SSE-NEXT:    psrlw $1, %xmm0
+; SSE-NEXT:    retq
+;
+; AVX-LABEL: mul_and_shift17:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpmulhuw %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vpsrlw $1, %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %a.ext = zext <8 x i16> %a to <8 x i32>
+  %b.ext = zext <8 x i16> %b to <8 x i32>
+  %mul = mul <8 x i32> %a.ext, %b.ext
+  %shift = lshr <8 x i32> %mul, splat(i32 17)
+  %trunc = trunc <8 x i32> %shift to <8 x i16>
+  ret <8 x i16> %trunc
+}
+
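
The rewrite in the diff relies on a bit-level identity: for zero-extended 16-bit inputs, shifting the 32-bit product right by N (16 <= N < 32) and truncating to 16 bits gives the same lanes as MULHU followed by a logical right shift by N - 16. The sketch below is a scalar model of a single lane in plain C++, not LLVM code; the function names (mulhu16, widePattern, narrowPattern, checkPatterns) are made up for illustration. It compares both forms over a strided sample of inputs rather than exhaustively.

```cpp
#include <cassert>
#include <cstdint>

// Scalar model of one i16 lane of ISD::MULHU: the high 16 bits of the
// unsigned 32-bit product.
uint16_t mulhu16(uint16_t a, uint16_t b) {
  return static_cast<uint16_t>((static_cast<uint32_t>(a) * b) >> 16);
}

// Original pattern: trunc(srl(mul(zext a, zext b), shift)), 16 <= shift < 32.
uint16_t widePattern(uint16_t a, uint16_t b, unsigned shift) {
  uint32_t product = static_cast<uint32_t>(a) * static_cast<uint32_t>(b);
  return static_cast<uint16_t>(product >> shift);
}

// Rewritten pattern: srl(mulhu(a, b), shift - 16).
uint16_t narrowPattern(uint16_t a, uint16_t b, unsigned shift) {
  return static_cast<uint16_t>(mulhu16(a, b) >> (shift - 16));
}

// Compares both forms for every shift in [16, 31] over a strided sample of
// inputs; both strides land exactly on 0xFFFF, so the extremes are covered.
void checkPatterns() {
  for (unsigned shift = 16; shift < 32; ++shift)
    for (uint32_t a = 0; a <= 0xFFFF; a += 255)
      for (uint32_t b = 0; b <= 0xFFFF; b += 257)
        assert(widePattern(static_cast<uint16_t>(a), static_cast<uint16_t>(b), shift) ==
               narrowPattern(static_cast<uint16_t>(a), static_cast<uint16_t>(b), shift));
}
```

Calling checkPatterns() exercises the shift by 17 from the test above as well as the larger shifts up to 31.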

m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_ConstInt(ShiftAmt))))
return SDValue();

if (ShiftAmt.ult(16))
Collaborator

Maybe worth checking for ShiftAmt.uge(InVT.getScalarSizeInBits()) as well?

Contributor Author

@abhishek-kaushik22 Jul 13, 2025

I tried using a shift value of that size, but the DAG builder replaces that node with undef, so the condition was never hit.
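
For reference, the range being discussed can be written as a small guard. This is a hypothetical standalone helper in plain C++ (additionalShiftFor is an invented name and is not part of the patch); it only models the arithmetic of "reject shifts below 16 or at least as wide as the scalar type, otherwise return the leftover shift", assuming the scalar width is passed in by the caller.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical helper, not LLVM API: returns the extra shift to apply after
// the high-half multiply, or nullopt when the pattern should not be matched.
std::optional<unsigned> additionalShiftFor(uint64_t shiftAmt, unsigned scalarBits) {
  // Below 16 the low half of the product would be needed; at or above the
  // scalar width the shift is out of range for the type.
  if (shiftAmt < 16 || shiftAmt >= scalarBits)
    return std::nullopt;
  return static_cast<unsigned>(shiftAmt - 16);
}

// Small self-check of the boundaries for a 32-bit scalar type.
void checkGuard() {
  assert(!additionalShiftFor(15, 32).has_value());
  assert(additionalShiftFor(16, 32).value() == 0u);
  assert(additionalShiftFor(31, 32).value() == 15u);
  assert(!additionalShiftFor(32, 32).has_value());
}
```

Per the reply above, the upper bound may be unreachable in practice because the oversized shift is folded to undef before the combine runs, so the second half of the condition is defensive only.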

return DAG.getNode(Opc, DL, VT, LHS, RHS);
SDValue Res = DAG.getNode(Opc, DL, VT, LHS, RHS);
return DAG.getNode(ISD::SRL, DL, VT, Res,
DAG.getConstant(AdditionalShift, DL, VT));
Collaborator

Should the IsSigned case use ISD::SRA?

Use getShiftAmountConstant

Contributor Author

Using ISD::SRA gives wrong results

https://godbolt.org/z/3PYaPb1be
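
A scalar illustration of why an arithmetic shift does not fit here, written independently of the linked example (its contents are not reproduced in this thread). It assumes the matched pattern is trunc(srl(mul(sext a, sext b), N)), i.e. the outer shift is logical, and models one i16 lane in plain C++; all function names are made up. Because the original shift fills the top with zeros, the narrowed form also has to fill with zeros, and replicating the sign bit changes the result whenever the product is negative.

```cpp
#include <cassert>
#include <cstdint>

// Scalar model of one i16 lane of ISD::MULHS: high 16 bits of the signed
// 32-bit product, returned as a raw bit pattern.
uint16_t mulhs16(int16_t a, int16_t b) {
  int32_t product = static_cast<int32_t>(a) * static_cast<int32_t>(b);
  return static_cast<uint16_t>(static_cast<uint32_t>(product) >> 16);
}

// Reference: trunc(srl(mul(sext a, sext b), shift)), 16 <= shift < 32.
uint16_t reference(int16_t a, int16_t b, unsigned shift) {
  int32_t product = static_cast<int32_t>(a) * static_cast<int32_t>(b);
  return static_cast<uint16_t>(static_cast<uint32_t>(product) >> shift);
}

// Candidate 1: logical shift of the high half (zeros shifted in).
uint16_t withSrl(int16_t a, int16_t b, unsigned shift) {
  return static_cast<uint16_t>(mulhs16(a, b) >> (shift - 16));
}

// Candidate 2: arithmetic shift of the high half, spelled out with masks so
// the behaviour does not depend on how the compiler shifts negative values.
uint16_t withSra(int16_t a, int16_t b, unsigned shift) {
  uint16_t hi = mulhs16(a, b);
  unsigned extra = shift - 16;
  uint16_t logical = static_cast<uint16_t>(hi >> extra);
  uint16_t fill = (hi & 0x8000) ? static_cast<uint16_t>(~(0xFFFFu >> extra)) : 0;
  return static_cast<uint16_t>(logical | fill);
}

// -1 * 1 has the bit pattern 0xFFFFFFFF; a logical shift by 17 leaves 0x7FFF.
void checkSignedCase() {
  assert(reference(-1, 1, 17) == 0x7FFF);
  assert(withSrl(-1, 1, 17) == 0x7FFF);   // matches the reference
  assert(withSra(-1, 1, 17) == 0xFFFF);   // sign fill diverges
}
```

For non-negative products the two candidates agree, which is why the difference only shows up with inputs of opposite sign.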

return SDValue();

APInt AdditionalShift = (ShiftAmt - 16).trunc(16);

Collaborator

You need to do some testing to see what happens with MULHS - probably limit AdditionalShift != 0 to just the IsUnsigned case for starters?

@RKSimon
Collaborator

RKSimon commented Jul 6, 2025

@abhishek-kaushik22 reverse ping

%shift = lshr <8 x i32> %mul, splat(i32 17)
%trunc = trunc <8 x i32> %shift to <8 x i16>
ret <8 x i16> %trunc
}
Collaborator

It'd be better if we tested a range of shift values as well as the shift by 17 from the original bug report - shifts by 24 and 31 will be interesting for instance.

Collaborator

@RKSimon left a comment

2 minors

@abhishek-kaushik22
Contributor Author

@RKSimon can you please review again?

Collaborator

@RKSimon left a comment

LGTM - cheers

@abhishek-kaushik22 merged commit 11eeb4d into llvm:main on Aug 5, 2025
9 checks passed
@abhishek-kaushik22 deleted the 132166 branch on August 5, 2025 at 11:02
Development

Successfully merging this pull request may close these issues.

[X86] mulhu + srl pattern not recognized