
Conversation

@davemgreen
Collaborator

This patch adds recognition of a high-half multiply computed by parts,
rewriting it into a single larger multiply.

Considering a multiply made up of high and low parts, we can split the
multiply into:

 x * y == (xh*T + xl) * (yh*T + yl)

where xh == x>>32 and xl == x & 0xffffffff. T = 2^32.
This expands to

 xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl

which I find helpful to draw as

[  xh*yh  ]
     [  xh*yl  ]
     [  xl*yh  ]
          [  xl*yl  ]

We are looking for the "high" half, which is xh*yh + xh*yl>>32 + xl*yh>>32 +
carries. The carries make this difficult and there are multiple ways of
representing them. The ones we attempt to support here are:

 Carry:  xh*yh + carry + lowsum
         carry = lowsum < xh*yl ? 0x100000000 : 0
         lowsum = xh*yl + xl*yh + (xl*yl>>32)
 Ladder: xh*yh + c2>>32 + c3>>32
         c2 = xh*yl + (xl*yl >> 32); c3 = c2&0xffffffff + xl*yh
 Carry4: xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32
         crosssum = xh*yl + xl*yh
         carry = crosssum < xh*yl ? 0x100000000 : 0
 Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
         low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff

They all start by matching xh*yh plus 2 or 3 other operands. The bottom of the
tree is xh*yh, xh*yl, xl*yh and xl*yl.
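
As a concrete example, the Carry variant in C looks like the following sketch
(64-bit inputs with 32-bit halves; the i32 tests below use 16-bit halves, and
the name umulh_carry is illustrative):

#include <stdint.h>

// High 64 bits of x*y, computed by 32-bit parts in the Carry shape.
uint64_t umulh_carry(uint64_t x, uint64_t y) {
  uint64_t xh = x >> 32, xl = x & 0xffffffff;
  uint64_t yh = y >> 32, yl = y & 0xffffffff;
  uint64_t xhyl = xh * yl, xlyh = xl * yh, xlyl = xl * yl;
  uint64_t lowsum = xhyl + xlyh + (xlyl >> 32);
  // If the sum wrapped past 2^64 it is now smaller than one of its
  // addends; the lost bit re-enters at position 32.
  uint64_t carry = lowsum < xhyl ? 0x100000000ULL : 0;
  return xh * yh + carry + (lowsum >> 32);
}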

Based on #156879 by @c-rhodes

@llvmbot
Member

llvmbot commented Nov 17, 2025

@llvm/pr-subscribers-llvm-transforms

Author: David Green (davemgreen)


Patch is 226.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168396.diff

5 Files Affected:

  • (modified) llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp (+301)
  • (added) llvm/test/Transforms/AggressiveInstCombine/umulh_carry.ll (+755)
  • (added) llvm/test/Transforms/AggressiveInstCombine/umulh_carry4.ll (+3019)
  • (added) llvm/test/Transforms/AggressiveInstCombine/umulh_ladder.ll (+818)
  • (added) llvm/test/Transforms/AggressiveInstCombine/umulh_ladder4.ll (+530)
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
index b575d76e897d2..fb71f57eaa502 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -1466,6 +1466,306 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
   return false;
 }
 
+/// Match high part of long multiplication.
+///
+/// Considering a multiply made up of high and low parts, we can split the
+/// multiply into:
+///  x * y == (xh*T + xl) * (yh*T + yl)
+/// where xh == x>>32 and xl == x & 0xffffffff. T = 2^32.
+/// This expands to
+///  xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl
+/// which can be drawn as
+/// [  xh*yh  ]
+///      [  xh*yl  ]
+///      [  xl*yh  ]
+///           [  xl*yl  ]
+/// We are looking for the "high" half, which is xh*yh + xh*yl>>32 + xl*yh>>32 +
+/// some carries. The carry makes this difficult and there are multiple ways of
+/// representing it. The ones we attempt to support here are:
+///  Carry:  xh*yh + carry + lowsum
+///          carry = lowsum < xh*yl ? 0x100000000 : 0
+///          lowsum = xh*yl + xl*yh + (xl*yl>>32)
+///  Ladder: xh*yh + c2>>32 + c3>>32
+///          c2 = xh*yl + (xl*yl>>32); c3 = c2&0xffffffff + xl*yh
+///  Carry4: xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32
+///          crosssum = xh*yl + xl*yh
+///          carry = crosssum < xh*yl ? 0x100000000 : 0
+///  Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
+///          low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
+///
+/// They all start by matching xh*yh + 2 or 3 other operands. The bottom of the
+/// tree is xh*yh, xh*yl, xl*yh and xl*yl.
+static bool foldMulHigh(Instruction &I) {
+  Type *Ty = I.getType();
+  if (!Ty->isIntOrIntVectorTy())
+    return false;
+
+  unsigned BW = Ty->getScalarSizeInBits();
+  APInt LowMask = APInt::getLowBitsSet(BW, BW / 2);
+  if (BW % 2 != 0)
+    return false;
+
+  auto CreateMulHigh = [&](Value *X, Value *Y) {
+    IRBuilder<> Builder(&I);
+    Type *NTy = Ty->getWithNewBitWidth(BW * 2);
+    Value *XExt = Builder.CreateZExt(X, NTy);
+    Value *YExt = Builder.CreateZExt(Y, NTy);
+    Value *Mul = Builder.CreateMul(XExt, YExt);
+    Value *High = Builder.CreateLShr(Mul, BW);
+    Value *Res = Builder.CreateTrunc(High, Ty);
+    I.replaceAllUsesWith(Res);
+    LLVM_DEBUG(dbgs() << "Created long multiply from parts of " << *X << " and "
+                      << *Y << "\n");
+    return true;
+  };
+
+  // Common check routines for X_lo*Y_lo and X_hi*Y_lo
+  auto CheckLoLo = [&](Value *XlYl, Value *X, Value *Y) {
+    return match(XlYl, m_c_Mul(m_And(m_Specific(X), m_SpecificInt(LowMask)),
+                               m_And(m_Specific(Y), m_SpecificInt(LowMask))));
+  };
+  auto CheckHiLo = [&](Value *XhYl, Value *X, Value *Y) {
+    return match(XhYl, m_c_Mul(m_LShr(m_Specific(X), m_SpecificInt(BW / 2)),
+                               m_And(m_Specific(Y), m_SpecificInt(LowMask))));
+  };
+
+  auto foldMulHighCarry = [&](Value *X, Value *Y, Instruction *Carry,
+                              Instruction *B) {
+    // Looking for LowSum >> 32 and carry (select)
+    if (Carry->getOpcode() != Instruction::Select)
+      std::swap(Carry, B);
+
+    // Carry = LowSum < XhYl ? 0x100000000 : 0
+    CmpPredicate Pred;
+    Value *LowSum, *XhYl;
+    if (!match(Carry,
+               m_OneUse(m_Select(
+                   m_OneUse(m_ICmp(Pred, m_Value(LowSum), m_Value(XhYl))),
+                   m_SpecificInt(APInt(BW, 1) << BW / 2), m_SpecificInt(0)))) ||
+        Pred != ICmpInst::ICMP_ULT)
+      return false;
+
+    // XhYl can be Xh*Yl or Xl*Yh
+    if (!CheckHiLo(XhYl, X, Y)) {
+      if (CheckHiLo(XhYl, Y, X))
+        std::swap(X, Y);
+      else
+        return false;
+    }
+    if (XhYl->hasNUsesOrMore(3))
+      return false;
+
+    // B = LowSum >> 32
+    if (!match(B,
+               m_OneUse(m_LShr(m_Specific(LowSum), m_SpecificInt(BW / 2)))) ||
+        LowSum->hasNUsesOrMore(3))
+      return false;
+
+    // LowSum = XhYl + XlYh + XlYl>>32
+    Value *XlYh, *XlYl;
+    auto XlYlHi = m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2));
+    if (!match(LowSum,
+               m_c_Add(m_Specific(XhYl),
+                       m_OneUse(m_c_Add(m_OneUse(m_Value(XlYh)), XlYlHi)))) &&
+        !match(LowSum, m_c_Add(m_OneUse(m_Value(XlYh)),
+                               m_OneUse(m_c_Add(m_Specific(XhYl), XlYlHi)))) &&
+        !match(LowSum,
+               m_c_Add(XlYlHi, m_OneUse(m_c_Add(m_Specific(XhYl),
+                                                m_OneUse(m_Value(XlYh)))))))
+      return false;
+
+    // Check XlYl and XlYh
+    if (!CheckLoLo(XlYl, X, Y))
+      return false;
+    if (!CheckHiLo(XlYh, Y, X))
+      return false;
+
+    return CreateMulHigh(X, Y);
+  };
+
+  auto foldMulHighLadder = [&](Value *X, Value *Y, Instruction *A,
+                               Instruction *B) {
+    //  xh*yh + c2>>32 + c3>>32
+    //  c2 = xh*yl + (xl*yl >> 32); c3 = c2&0xffffffff + xl*yh
+    Value *XlYh, *XhYl, *C2, *C3;
+    // Strip off the two expected shifts.
+    if (!match(A, m_LShr(m_Value(C2), m_SpecificInt(BW / 2))) ||
+        !match(B, m_LShr(m_Value(C3), m_SpecificInt(BW / 2))))
+      return false;
+
+    // Match c3 = c2&0xffffffff + xl*yh
+    if (!match(C3, m_c_Add(m_And(m_Specific(C2), m_SpecificInt(LowMask)),
+                           m_Value(XhYl))))
+      std::swap(C2, C3);
+    if (!match(C3,
+               m_c_Add(m_OneUse(m_And(m_Specific(C2), m_SpecificInt(LowMask))),
+                       m_Value(XhYl))) ||
+        !C3->hasOneUse() || C2->hasNUsesOrMore(3))
+      return false;
+
+    // Match c2 = xh*yl + (xl*yl >> 32)
+    Value *XlYl;
+    if (!match(C2, m_c_Add(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2)),
+                           m_Value(XlYh))))
+      return false;
+
+    // Match XhYl and XlYh - they can appear either way around.
+    if (!CheckHiLo(XlYh, Y, X))
+      std::swap(XlYh, XhYl);
+    if (!CheckHiLo(XlYh, Y, X))
+      return false;
+    if (!CheckHiLo(XhYl, X, Y))
+      return false;
+    if (!CheckLoLo(XlYl, X, Y))
+      return false;
+
+    return CreateMulHigh(X, Y);
+  };
+
+  auto foldMulHighLadder4 = [&](Value *X, Value *Y, Instruction *A,
+                                Instruction *B, Instruction *C) {
+    ///  Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
+    ///           low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
+
+    // Find A = Low >> 32 and B/C = XhYl>>32, XlYh>>32.
+    auto ShiftAdd = m_LShr(m_Add(m_Value(), m_Value()), m_SpecificInt(BW / 2));
+    if (!match(A, ShiftAdd))
+      std::swap(A, B);
+    if (!match(A, ShiftAdd))
+      std::swap(A, C);
+    Value *Low;
+    if (!match(A, m_LShr(m_OneUse(m_Value(Low)), m_SpecificInt(BW / 2))))
+      return false;
+
+    // Match B == XhYl>>32 and C == XlYh>>32
+    Value *XhYl, *XlYh;
+    if (!match(B, m_LShr(m_Value(XhYl), m_SpecificInt(BW / 2))) ||
+        !match(C, m_LShr(m_Value(XlYh), m_SpecificInt(BW / 2))))
+      return false;
+    if (!CheckHiLo(XhYl, X, Y))
+      std::swap(XhYl, XlYh);
+    if (!CheckHiLo(XhYl, X, Y) || XhYl->hasNUsesOrMore(3))
+      return false;
+    if (!CheckHiLo(XlYh, Y, X) || XlYh->hasNUsesOrMore(3))
+      return false;
+
+    // Match Low as XlYl>>32 + XhYl&0xffffffff + XlYh&0xffffffff
+    Value *XlYl;
+    if (!match(
+            Low,
+            m_c_Add(
+                m_OneUse(m_c_Add(
+                    m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))),
+                    m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))))),
+                m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))))) &&
+        !match(
+            Low,
+            m_c_Add(
+                m_OneUse(m_c_Add(
+                    m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))),
+                    m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))))),
+                m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))))) &&
+        !match(
+            Low,
+            m_c_Add(
+                m_OneUse(m_c_Add(
+                    m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))),
+                    m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))))),
+                m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))))))
+      return false;
+    if (!CheckLoLo(XlYl, X, Y))
+      return false;
+
+    return CreateMulHigh(X, Y);
+  };
+
+  auto foldMulHighCarry4 = [&](Value *X, Value *Y, Instruction *Carry,
+                               Instruction *B, Instruction *C) {
+    //  xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32
+    //  crosssum = xh*yl+xl*yh
+    //  carry = crosssum < xh*yl ? 0x100000000 : 0
+    if (Carry->getOpcode() != Instruction::Select)
+      std::swap(Carry, B);
+    if (Carry->getOpcode() != Instruction::Select)
+      std::swap(Carry, C);
+
+    // Carry = CrossSum < XhYl ? 0x100000000 : 0
+    CmpPredicate Pred;
+    Value *CrossSum, *XhYl;
+    if (!match(Carry,
+               m_OneUse(m_Select(
+                   m_OneUse(m_ICmp(Pred, m_Value(CrossSum), m_Value(XhYl))),
+                   m_SpecificInt(APInt(BW, 1) << BW / 2), m_SpecificInt(0)))) ||
+        Pred != ICmpInst::ICMP_ULT)
+      return false;
+
+    if (!match(B, m_LShr(m_Specific(CrossSum), m_SpecificInt(BW / 2))))
+      std::swap(B, C);
+    if (!match(B, m_LShr(m_Specific(CrossSum), m_SpecificInt(BW / 2))))
+      return false;
+
+    Value *XlYl, *LowAccum;
+    if (!match(C, m_LShr(m_Value(LowAccum), m_SpecificInt(BW / 2))) ||
+        !match(LowAccum,
+               m_c_Add(m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))),
+                       m_OneUse(m_And(m_Specific(CrossSum),
+                                      m_SpecificInt(LowMask))))) ||
+        LowAccum->hasNUsesOrMore(3))
+      return false;
+    if (!CheckLoLo(XlYl, X, Y))
+      return false;
+
+    if (!CheckHiLo(XhYl, X, Y))
+      std::swap(X, Y);
+    if (!CheckHiLo(XhYl, X, Y))
+      return false;
+    if (!match(CrossSum,
+               m_c_Add(m_Specific(XhYl),
+                       m_OneUse(m_c_Mul(
+                           m_LShr(m_Specific(Y), m_SpecificInt(BW / 2)),
+                           m_And(m_Specific(X), m_SpecificInt(LowMask)))))) ||
+        CrossSum->hasNUsesOrMore(4) || XhYl->hasNUsesOrMore(3))
+      return false;
+
+    return CreateMulHigh(X, Y);
+  };
+
+  // X and Y are the two inputs, A, B and C are other parts of the pattern
+  // (crosssum>>32, carry, etc).
+  Value *X, *Y;
+  Instruction *A, *B, *C;
+  auto HiHi = m_OneUse(m_Mul(m_LShr(m_Value(X), m_SpecificInt(BW / 2)),
+                             m_LShr(m_Value(Y), m_SpecificInt(BW / 2))));
+  if ((match(&I, m_c_Add(HiHi, m_OneUse(m_Add(m_Instruction(A),
+                                              m_Instruction(B))))) ||
+       match(&I, m_c_Add(m_Instruction(A),
+                         m_OneUse(m_c_Add(HiHi, m_Instruction(B)))))) &&
+      A->hasOneUse() && B->hasOneUse())
+    if (foldMulHighCarry(X, Y, A, B) || foldMulHighLadder(X, Y, A, B))
+      return true;
+
+  if ((match(&I, m_c_Add(HiHi, m_OneUse(m_c_Add(
+                                   m_Instruction(A),
+                                   m_OneUse(m_Add(m_Instruction(B),
+                                                  m_Instruction(C))))))) ||
+       match(&I, m_c_Add(m_Instruction(A),
+                         m_OneUse(m_c_Add(
+                             HiHi, m_OneUse(m_Add(m_Instruction(B),
+                                                  m_Instruction(C))))))) ||
+       match(&I, m_c_Add(m_Instruction(A),
+                         m_OneUse(m_c_Add(
+                             m_Instruction(B),
+                             m_OneUse(m_c_Add(HiHi, m_Instruction(C))))))) ||
+       match(&I,
+             m_c_Add(m_OneUse(m_c_Add(HiHi, m_Instruction(A))),
+                     m_OneUse(m_Add(m_Instruction(B), m_Instruction(C)))))) &&
+      A->hasOneUse() && B->hasOneUse() && C->hasOneUse())
+    return foldMulHighCarry4(X, Y, A, B, C) ||
+           foldMulHighLadder4(X, Y, A, B, C);
+
+  return false;
+}
+
 /// This is the entry point for folds that could be implemented in regular
 /// InstCombine, but they are separated because they are not expected to
 /// occur frequently and/or have more than a constant-length pattern match.
@@ -1495,6 +1795,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
       MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
       MadeChange |= foldPatternedLoads(I, DL);
       MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT);
+      MadeChange |= foldMulHigh(I);
       // NOTE: This function introduces erasing of the instruction `I`, so it
       // needs to be called at the end of this sequence, otherwise we may make
       // bugs.
diff --git a/llvm/test/Transforms/AggressiveInstCombine/umulh_carry.ll b/llvm/test/Transforms/AggressiveInstCombine/umulh_carry.ll
new file mode 100644
index 0000000000000..b78095cac0df9
--- /dev/null
+++ b/llvm/test/Transforms/AggressiveInstCombine/umulh_carry.ll
@@ -0,0 +1,755 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=aggressive-instcombine,instcombine -S | FileCheck %s
+
+; Carry variant of mul-high. https://alive2.llvm.org/ce/z/G2bD6o
+define i32 @mul_carry(i32 %x, i32 %y) {
+; CHECK-LABEL: define i32 @mul_carry(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = zext i32 [[X]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i32 [[Y]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw i64 [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr i64 [[TMP2]], 32
+; CHECK-NEXT:    [[ADD11:%.*]] = trunc nuw i64 [[TMP3]] to i32
+; CHECK-NEXT:    ret i32 [[ADD11]]
+;
+entry:
+  %shr = lshr i32 %x, 16
+  %and = and i32 %x, 65535
+  %shr1 = lshr i32 %y, 16
+  %and2 = and i32 %y, 65535
+  %mul = mul nuw i32 %shr, %and2
+  %mul3 = mul nuw i32 %and, %shr1
+  %add = add i32 %mul, %mul3
+  %mul4 = mul nuw i32 %and, %and2
+  %shr5 = lshr i32 %mul4, 16
+  %add6 = add i32 %add, %shr5
+  %cmp = icmp ult i32 %add6, %mul
+  %cond = select i1 %cmp, i32 65536, i32 0
+  %mul8 = mul nuw i32 %shr, %shr1
+  %add9 = add nuw i32 %mul8, %cond
+  %shr10 = lshr i32 %add6, 16
+  %add11 = add i32 %add9, %shr10
+  ret i32 %add11
+}
+
+; Carry variant of mul-high. https://alive2.llvm.org/ce/z/G2bD6o
+define i128 @mul_carry_i128(i128 %x, i128 %y) {
+; CHECK-LABEL: define i128 @mul_carry_i128(
+; CHECK-SAME: i128 [[X:%.*]], i128 [[Y:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = zext i128 [[X]] to i256
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i128 [[Y]] to i256
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw i256 [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr i256 [[TMP2]], 128
+; CHECK-NEXT:    [[ADD11:%.*]] = trunc nuw i256 [[TMP3]] to i128
+; CHECK-NEXT:    ret i128 [[ADD11]]
+;
+entry:
+  %shr = lshr i128 %x, 64
+  %and = and i128 %x, u0xffffffffffffffff
+  %shr1 = lshr i128 %y, 64
+  %and2 = and i128 %y, u0xffffffffffffffff
+  %mul = mul nuw i128 %shr, %and2
+  %mul3 = mul nuw i128 %and, %shr1
+  %add = add i128 %mul, %mul3
+  %mul4 = mul nuw i128 %and, %and2
+  %shr5 = lshr i128 %mul4, 64
+  %add6 = add i128 %add, %shr5
+  %cmp = icmp ult i128 %add6, %mul
+  %cond = select i1 %cmp, i128 u0x10000000000000000, i128 0
+  %mul8 = mul nuw i128 %shr, %shr1
+  %add9 = add nuw i128 %mul8, %cond
+  %shr10 = lshr i128 %add6, 64
+  %add11 = add i128 %add9, %shr10
+  ret i128 %add11
+}
+
+; Carry variant of mul-high. https://alive2.llvm.org/ce/z/G2bD6o
+define <4 x i32> @mul_carry_v4i32(<4 x i32> %x, <4 x i32> %y) {
+; CHECK-LABEL: define <4 x i32> @mul_carry_v4i32(
+; CHECK-SAME: <4 x i32> [[X:%.*]], <4 x i32> [[Y:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = zext <4 x i32> [[X]] to <4 x i64>
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <4 x i32> [[Y]] to <4 x i64>
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw <4 x i64> [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr <4 x i64> [[TMP2]], splat (i64 32)
+; CHECK-NEXT:    [[ADD11:%.*]] = trunc nuw <4 x i64> [[TMP3]] to <4 x i32>
+; CHECK-NEXT:    ret <4 x i32> [[ADD11]]
+;
+entry:
+  %shr = lshr <4 x i32> %x, <i32 16, i32 16, i32 16, i32 16>
+  %and = and <4 x i32> %x, <i32 65535, i32 65535, i32 65535, i32 65535>
+  %shr1 = lshr <4 x i32> %y, <i32 16, i32 16, i32 16, i32 16>
+  %and2 = and <4 x i32> %y, <i32 65535, i32 65535, i32 65535, i32 65535>
+  %mul = mul nuw <4 x i32> %shr, %and2
+  %mul3 = mul nuw <4 x i32> %and, %shr1
+  %add = add <4 x i32> %mul, %mul3
+  %mul4 = mul nuw <4 x i32> %and, %and2
+  %shr5 = lshr <4 x i32> %mul4, <i32 16, i32 16, i32 16, i32 16>
+  %add6 = add <4 x i32> %add, %shr5
+  %cmp = icmp ult <4 x i32> %add6, %mul
+  %cond = select <4 x i1> %cmp, <4 x i32> <i32 65536, i32 65536, i32 65536, i32 65536>, <4 x i32> zeroinitializer
+  %mul8 = mul nuw <4 x i32> %shr, %shr1
+  %add9 = add nuw <4 x i32> %mul8, %cond
+  %shr10 = lshr <4 x i32> %add6, <i32 16, i32 16, i32 16, i32 16>
+  %add11 = add <4 x i32> %add9, %shr10
+  ret <4 x i32> %add11
+}
+
+; Check carry against xlyh, not xhyl
+define i32 @mul_carry_xlyh(i32 %x, i32 %y) {
+; CHECK-LABEL: define i32 @mul_carry_xlyh(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = zext i32 [[Y]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i32 [[X]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw i64 [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr i64 [[TMP2]], 32
+; CHECK-NEXT:    [[ADD11:%.*]] = trunc nuw i64 [[TMP3]] to i32
+; CHECK-NEXT:    ret i32 [[ADD11]]
+;
+entry:
+  %shr = lshr i32 %x, 16
+  %and = and i32 %x, 65535
+  %shr1 = lshr i32 %y, 16
+  %and2 = and i32 %y, 65535
+  %mul = mul nuw i32 %shr, %and2
+  %mul3 = mul nuw i32 %and, %shr1
+  %add = add i32 %mul, %mul3
+  %mul4 = mul nuw i32 %and, %and2
+  %shr5 = lshr i32 %mul4, 16
+  %add6 = add i32 %add, %shr5
+  %cmp = icmp ult i32 %add6, %mul3
+  %cond = select i1 %cmp, i32 65536, i32 0
+  %mul8 = mul nuw i32 %shr, %shr1
+  %add9 = add nuw i32 %mul8, %cond
+  %shr10 = lshr i32 %add6, 16
+  %add11 = add i32 %add9, %shr10
+  ret i32 %add11
+}
+
+define i32 @mul_carry_comm(i32 %x, i32 %y) {
+; CHECK-LABEL: define i32 @mul_carry_comm(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = zext i32 [[X]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i32 [[Y]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw i64 [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr i64 [[TMP2]], 32
+; CHECK-NEXT:    [[ADD11:%.*]] = trunc nuw i64 [[TMP3]] to i32
+; CHECK-NEXT:    ret i32 [[ADD11]]
+;
+entry:
+  %shr = lshr i32 %x, 16
+  %and = and i32 %x, 65535
+  %shr1 = lshr i32 %y, 16
+  %and2 = and i32 %y, 65535
+  %mul = mul nuw i32 %and2, %shr
+  %mul3 = mul nuw i32 %shr1, %and
+  %add = add i32 %mul3, %mul
+  %mul4 = mul nuw i32 %and, %and2
+  %shr5 = lshr i32 %mul4, 16
+  %add6 = add i32 %shr5, %add
+  %cmp = icmp ult i32 %add6, %mul
+  %cond = select i1 %cmp, i32 65536, i32 0
+  %mul8 = mul nuw i32 %shr, %shr1
+  %shr10 = lshr i32 %add6, 16
+  %add9 = add nuw i32 %cond, %shr10
+  %add11 = add i32 %add9, %mul8
+  ret i32 %add11
+}
+
+
+; Negative tests
+
+
+define i32 @mul_carry_notxlo(i32 %x, i32 %y) {
+; CHECK-LABEL: define i32 @mul_carry_notxlo(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[SHR:%.*]] = lshr i32 [[X]], 16
+; CHECK-NEXT:    [[AND:%.*]] = and i32 [[X]], 32767
+; CHECK-NEXT:    [[SHR1:%.*]] = lshr i32 [[Y]], 16
+; CHECK-NEXT:    [[AND2:%.*]] = and i32 [[Y]], 65535
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i32 [[SHR]], [[AND2]]
+; CHECK-NEXT:    [[MUL3:%.*]] = mul nuw nsw i32 [[AND]], [[SHR1]]
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[MUL]], [[MUL3]]
+; CHECK-NEXT:    [[MUL4:%.*]] = mul nuw nsw i32 [[AND]], [[AND2]]
+; CHECK-NEXT:    [[SHR5:%.*]] = lshr i32 [[MUL4]], 16
+; CHECK-NEXT:    [[ADD6:%.*]] = add i32 [[ADD]], [[SHR5]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[ADD6]], [[MUL]]
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 65536, i32 0
+; CHECK-NEXT:    [[MUL8:%.*]] = mul nuw i32 [[SHR]], [[SHR1]]
+; CHECK-NE...
[truncated]

@github-actions

github-actions bot commented Nov 19, 2025

🐧 Linux x64 Test Results

  • 186421 tests passed
  • 4863 tests skipped

Type *NTy = Ty->getWithNewBitWidth(BW * 2);
Value *XExt = Builder.CreateZExt(X, NTy);
Value *YExt = Builder.CreateZExt(Y, NTy);
Value *Mul = Builder.CreateMul(XExt, YExt);
Member

I am not sure if it is always profitable to use 2BW multiplication. If 2BW multiplication (or BW mulh) is not natively supported, we have to re-expand it during type legalization. But it seems that simpler IR results in shorter assembly: https://godbolt.org/z/sGvEWa8qz
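
For reference, the combined form is the single widening multiply, roughly the
following C (assuming unsigned __int128 support; the name umulh_wide is
illustrative):

#include <stdint.h>

// Post-combine form: widen once, multiply, keep the top half.
uint64_t umulh_wide(uint64_t x, uint64_t y) {
  return (uint64_t)(((unsigned __int128)x * y) >> 64);
}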

Collaborator Author

An i128->i256 variant actually came up in some of the testing I had run (I was surprised) - it wasn't any faster or slower when I tried it, just the same instructions in a different order. The i64->i128 case is the really important one. Ideally the compiler should be able to optimize the result to something that is better or the same. In this case it looks like we don't realise we are multiplying by 0 if we convert straight to a libcall. I will put together a patch.

/// carry = lowsum < xh*yl ? 0x100000000 : 0
/// lowsum = xh*yl + xl*yh + (xl*yl>>32)
/// Ladder: xh*yh + c2>>32 + c3>>32
/// c2 = xh*yl + (xl*yl>>32); c3 = c2&0xffffffff + xl*yh
Member

Here is another variant that is a bit different from this pattern: https://github.com/Cyan4973/xxHash/blob/136cc1f8fe4d5ea62a7c16c8424d4fa5158f6d68/xxhash.h#L4568-L4582

Collaborator Author

That's an interesting one - the graph looks like a simpler version of what I have called ladder. It has fewer cross-edges, but I've incorporated it into the logic of FoldMulHighLadder. It would be nice if some of these canonicalized together, but they are different enough that it seems difficult without matching the whole tree again.
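
For reference, that variant computes the upper half roughly as follows (a
paraphrase of the linked code, not the exact xxHash source):

#include <stdint.h>

// Simpler ladder: only one cross term feeds the intermediate sum.
uint64_t umulh_xxhash_style(uint64_t x, uint64_t y) {
  uint64_t lo_lo = (x & 0xffffffff) * (y & 0xffffffff);
  uint64_t hi_lo = (x >> 32) * (y & 0xffffffff);
  uint64_t lo_hi = (x & 0xffffffff) * (y >> 32);
  uint64_t hi_hi = (x >> 32) * (y >> 32);
  // cross cannot wrap: (lo_lo>>32) + (hi_lo & 0xffffffff) + lo_hi < 2^64.
  uint64_t cross = (lo_lo >> 32) + (hi_lo & 0xffffffff) + lo_hi;
  return (hi_lo >> 32) + (cross >> 32) + hi_hi;
}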

Collaborator

looks like the same one as #60200

if (!Ty->isIntOrIntVectorTy())
return false;

unsigned BW = Ty->getScalarSizeInBits();
Collaborator

nit: please spell it out

Suggested change
unsigned BW = Ty->getScalarSizeInBits();
unsigned BitWidth = Ty->getScalarSizeInBits();


Type *NTy = Ty->getWithNewBitWidth(BW * 2);
Value *XExt = Builder.CreateZExt(X, NTy);
Value *YExt = Builder.CreateZExt(Y, NTy);
Value *Mul = Builder.CreateMul(XExt, YExt, "", true);
Collaborator

Suggested change
Value *Mul = Builder.CreateMul(XExt, YExt, "", true);
Value *Mul = Builder.CreateMul(XExt, YExt, "", /*HasNUW=*/true);

same below

@dtcxzyw dtcxzyw left a comment
Member

LGTM w/ nits fixed. Thank you!
