-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[AggressiveInstCombine] Fold i64 x i64 -> i128 multiply-by-parts #156879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@llvm/pr-subscribers-llvm-transforms Author: Cullen Rhodes (c-rhodes) ChangesThis patch adds patterns to recognize a full i64 x i64 -> i128 multiplication by 4 x i32 parts, folding it to a full 128-bit multiply. The low/high parts are implemented as independent patterns. There's also an additional pattern for the high part, both patterns have been seen in real code, and there's one more I'm aware of but I thought I'd share a patch first to see what people think before handling any further cases. On AArch64 the mul and umulh instructions can be used to efficiently compute the low/high parts. I also believe X86 can do the i128 mul in one instruction (returning both halves). So it seems like this is relatively common and could be a useful optimization for several targets. Patch is 104.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/156879.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
index 40de36d81ddd2..cc9c341da032e 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -1428,6 +1428,259 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
return false;
}
+/// Match low part of 128-bit multiplication.
+static bool foldMul128Low(Instruction &I, const DataLayout &DL,
+ DominatorTree &DT) {
+ auto *Ty = I.getType();
+ if (!Ty->isIntegerTy(64))
+ return false;
+
+ // (low_accum << 32) | lo(lo(y) * lo(x))
+ Value *LowAccum = nullptr, *YLowXLow = nullptr;
+ if (!match(&I, m_c_DisjointOr(
+ m_OneUse(m_Shl(m_Value(LowAccum), m_SpecificInt(32))),
+ m_OneUse(
+ m_And(m_Value(YLowXLow), m_SpecificInt(0xffffffff))))))
+ return false;
+
+ // lo(cross_sum) + hi(lo(y) * lo(x))
+ Value *CrossSum = nullptr;
+ if (!match(
+ LowAccum,
+ m_c_Add(m_OneUse(m_And(m_Value(CrossSum), m_SpecificInt(0xffffffff))),
+ m_OneUse(m_LShr(m_Specific(YLowXLow), m_SpecificInt(32))))) ||
+ LowAccum->hasNUsesOrMore(3))
+ return false;
+
+ // (hi(y) * lo(x)) + (lo(y) * hi(x))
+ Value *YHigh = nullptr, *XLow = nullptr, *YLowXHigh = nullptr;
+ if (!match(CrossSum, m_c_Add(m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XLow))),
+ m_Value(YLowXHigh))) ||
+ CrossSum->hasNUsesOrMore(4))
+ return false;
+
+ // lo(y) * lo(x)
+ Value *YLow = nullptr;
+ if (!match(YLowXLow, m_c_Mul(m_Value(YLow), m_Specific(XLow))) ||
+ YLowXLow->hasNUsesOrMore(3))
+ return false;
+
+ // lo(y) * hi(x)
+ Value *XHigh = nullptr;
+ if (!match(YLowXHigh, m_c_Mul(m_Specific(YLow), m_Value(XHigh))) ||
+ !YLowXHigh->hasNUses(2))
+ return false;
+
+ Value *X = nullptr;
+ // lo(x) = x & 0xffffffff
+ if (!match(XLow, m_c_And(m_Value(X), m_SpecificInt(0xffffffff))) ||
+ !XLow->hasNUses(2))
+ return false;
+ // hi(x) = x >> 32
+ if (!match(XHigh, m_LShr(m_Specific(X), m_SpecificInt(32))) ||
+ !XHigh->hasNUses(2))
+ return false;
+
+ // Same for Y.
+ Value *Y = nullptr;
+ if (!match(YLow, m_c_And(m_Value(Y), m_SpecificInt(0xffffffff))) ||
+ !YLow->hasNUses(2))
+ return false;
+ if (!match(YHigh, m_LShr(m_Specific(Y), m_SpecificInt(32))) ||
+ !YHigh->hasNUses(2))
+ return false;
+
+ IRBuilder<> Builder(&I);
+ Value *XExt = Builder.CreateZExt(X, Builder.getInt128Ty());
+ Value *YExt = Builder.CreateZExt(Y, Builder.getInt128Ty());
+ Value *Mul128 = Builder.CreateMul(XExt, YExt);
+ Value *Res = Builder.CreateTrunc(Mul128, Builder.getInt64Ty());
+ I.replaceAllUsesWith(Res);
+
+ return true;
+}
+
+/// Match high part of 128-bit multiplication.
+static bool foldMul128High(Instruction &I, const DataLayout &DL,
+ DominatorTree &DT) {
+ auto *Ty = I.getType();
+ if (!Ty->isIntegerTy(64))
+ return false;
+
+ // intermediate_plus_carry + hi(low_accum)
+ Value *IntermediatePlusCarry = nullptr, *LowAccum = nullptr;
+ if (!match(&I,
+ m_c_Add(m_OneUse(m_Value(IntermediatePlusCarry)),
+ m_OneUse(m_LShr(m_Value(LowAccum), m_SpecificInt(32))))))
+ return false;
+
+ // match:
+ // (((hi(y) * hi(x)) + carry) + hi(cross_sum))
+ // or:
+ // ((hi(cross_sum) + (hi(y) * hi(x))) + carry)
+ CmpPredicate Pred;
+ Value *CrossSum = nullptr, *XHigh = nullptr, *YHigh = nullptr,
+ *Carry = nullptr;
+ if (!match(IntermediatePlusCarry,
+ m_c_Add(m_c_Add(m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XHigh))),
+ m_Value(Carry)),
+ m_OneUse(m_LShr(m_Value(CrossSum), m_SpecificInt(32))))) &&
+ !match(IntermediatePlusCarry,
+ m_c_Add(m_OneUse(m_c_Add(
+ m_OneUse(m_LShr(m_Value(CrossSum), m_SpecificInt(32))),
+ m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XHigh))))),
+ m_Value(Carry))))
+ return false;
+
+ // (select (icmp ult cross_sum, (lo(y) * hi(x))), (1 << 32), 0)
+ Value *YLowXHigh = nullptr;
+ if (!match(Carry,
+ m_OneUse(m_Select(m_OneUse(m_ICmp(Pred, m_Specific(CrossSum),
+ m_Value(YLowXHigh))),
+ m_SpecificInt(4294967296), m_SpecificInt(0)))) ||
+ Pred != ICmpInst::ICMP_ULT)
+ return false;
+
+ // (hi(y) * lo(x)) + (lo(y) * hi(x))
+ Value *XLow = nullptr;
+ if (!match(CrossSum,
+ m_c_Add(m_OneUse(m_c_Mul(m_Specific(YHigh), m_Value(XLow))),
+ m_Specific(YLowXHigh))) ||
+ CrossSum->hasNUsesOrMore(4))
+ return false;
+
+ // lo(y) * hi(x)
+ Value *YLow = nullptr;
+ if (!match(YLowXHigh, m_c_Mul(m_Value(YLow), m_Specific(XHigh))) ||
+ !YLowXHigh->hasNUses(2))
+ return false;
+
+ // lo(cross_sum) + hi(lo(y) * lo(x))
+ Value *YLowXLow = nullptr;
+ if (!match(LowAccum,
+ m_c_Add(m_OneUse(m_c_And(m_Specific(CrossSum),
+ m_SpecificInt(0xffffffff))),
+ m_OneUse(m_LShr(m_Value(YLowXLow), m_SpecificInt(32))))) ||
+ LowAccum->hasNUsesOrMore(3))
+ return false;
+
+ // lo(y) * lo(x)
+ //
+ // When only doing the high part there's a single use and 2 uses when doing
+ // full multiply. Given the low/high patterns are separate, it's non-trivial
+ // to vary the number of uses to check this, but applying the optimization
+ // when there's an unrelated use when only doing the high part still results
+ // in less instructions and is likely profitable, so an upper bound of 2 uses
+ // should be fine.
+ if (!match(YLowXLow, m_c_Mul(m_Specific(YLow), m_Specific(XLow))) ||
+ YLowXLow->hasNUsesOrMore(3))
+ return false;
+
+ Value *X = nullptr;
+ // lo(x) = x & 0xffffffff
+ if (!match(XLow, m_c_And(m_Value(X), m_SpecificInt(0xffffffff))) ||
+ !XLow->hasNUses(2))
+ return false;
+ // hi(x) = x >> 32
+ if (!match(XHigh, m_LShr(m_Specific(X), m_SpecificInt(32))) ||
+ !XHigh->hasNUses(2))
+ return false;
+
+ // Same for Y.
+ Value *Y = nullptr;
+ if (!match(YLow, m_c_And(m_Value(Y), m_SpecificInt(0xffffffff))) ||
+ !YLow->hasNUses(2))
+ return false;
+ if (!match(YHigh, m_LShr(m_Specific(Y), m_SpecificInt(32))) ||
+ !YHigh->hasNUses(2))
+ return false;
+
+ IRBuilder<> Builder(&I);
+ Value *XExt = Builder.CreateZExt(X, Builder.getInt128Ty());
+ Value *YExt = Builder.CreateZExt(Y, Builder.getInt128Ty());
+ Value *Mul128 = Builder.CreateMul(XExt, YExt);
+ Value *High = Builder.CreateLShr(Mul128, 64);
+ Value *Res = Builder.CreateTrunc(High, Builder.getInt64Ty());
+ I.replaceAllUsesWith(Res);
+
+ return true;
+}
+
+/// Match another variant of high part of 128-bit multiplication.
+///
+/// %t0 = mul nuw i64 %y_lo, %x_lo
+/// %t1 = mul nuw i64 %y_lo, %x_hi
+/// %t2 = mul nuw i64 %y_hi, %x_lo
+/// %t3 = mul nuw i64 %y_hi, %x_hi
+/// %t0_hi = lshr i64 %t0, 32
+/// %u0 = add nuw i64 %t0_hi, %t1
+/// %u0_lo = and i64 %u0, 4294967295
+/// %u0_hi = lshr i64 %u0, 32
+/// %u1 = add nuw i64 %u0_lo, %t2
+/// %u1_hi = lshr i64 %u1, 32
+/// %u2 = add nuw i64 %u0_hi, %t3
+/// %hw64 = add nuw i64 %u2, %u1_hi
+static bool foldMul128HighVariant(Instruction &I, const DataLayout &DL,
+ DominatorTree &DT) {
+ auto *Ty = I.getType();
+ if (!Ty->isIntegerTy(64))
+ return false;
+
+ // hw64 = (hi(u0) + (hi(y) * hi(x)) + (lo(u0) + (hi(y) * lo(x)) >> 32))
+ Value *U0 = nullptr, *XHigh = nullptr, *YHigh = nullptr, *XLow = nullptr;
+ if (!match(
+ &I,
+ m_c_Add(m_OneUse(m_c_Add(
+ m_OneUse(m_LShr(m_Value(U0), m_SpecificInt(32))),
+ m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XHigh))))),
+ m_OneUse(m_LShr(
+ m_OneUse(m_c_Add(
+ m_OneUse(m_c_And(m_Deferred(U0),
+ m_SpecificInt(0xffffffff))),
+ m_OneUse(m_c_Mul(m_Deferred(YHigh), m_Value(XLow))))),
+ m_SpecificInt(32))))))
+ return false;
+
+ // u0 = (hi(lo(y) * lo(x)) + (lo(y) * hi(x)))
+ Value *YLow = nullptr;
+ if (!match(U0,
+ m_c_Add(m_OneUse(m_LShr(
+ m_OneUse(m_c_Mul(m_Value(YLow), m_Specific(XLow))),
+ m_SpecificInt(32))),
+ m_OneUse(m_c_Mul(m_Deferred(YLow), m_Specific(XHigh))))) ||
+ !U0->hasNUses(2))
+ return false;
+
+ Value *X = nullptr;
+ // lo(x) = x & 0xffffffff
+ if (!match(XLow, m_c_And(m_Value(X), m_SpecificInt(0xffffffff))) ||
+ !XLow->hasNUses(2))
+ return false;
+ // hi(x) = x >> 32
+ if (!match(XHigh, m_LShr(m_Specific(X), m_SpecificInt(32))) ||
+ !XHigh->hasNUses(2))
+ return false;
+
+ // Same for Y.
+ Value *Y = nullptr;
+ if (!match(YLow, m_c_And(m_Value(Y), m_SpecificInt(0xffffffff))) ||
+ !YLow->hasNUses(2))
+ return false;
+ if (!match(YHigh, m_LShr(m_Specific(Y), m_SpecificInt(32))) ||
+ !YHigh->hasNUses(2))
+ return false;
+
+ IRBuilder<> Builder(&I);
+ Value *XExt = Builder.CreateZExt(X, Builder.getInt128Ty());
+ Value *YExt = Builder.CreateZExt(Y, Builder.getInt128Ty());
+ Value *Mul128 = Builder.CreateMul(XExt, YExt);
+ Value *High = Builder.CreateLShr(Mul128, 64);
+ Value *Res = Builder.CreateTrunc(High, Builder.getInt64Ty());
+ I.replaceAllUsesWith(Res);
+
+ return true;
+}
+
/// This is the entry point for folds that could be implemented in regular
/// InstCombine, but they are separated because they are not expected to
/// occur frequently and/or have more than a constant-length pattern match.
@@ -1457,6 +1710,9 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
MadeChange |= foldPatternedLoads(I, DL);
MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT);
+ MadeChange |= foldMul128Low(I, DL, DT);
+ MadeChange |= foldMul128High(I, DL, DT);
+ MadeChange |= foldMul128HighVariant(I, DL, DT);
// NOTE: This function introduces erasing of the instruction `I`, so it
// needs to be called at the end of this sequence, otherwise we may make
// bugs.
diff --git a/llvm/test/Transforms/AggressiveInstCombine/umulh.ll b/llvm/test/Transforms/AggressiveInstCombine/umulh.ll
new file mode 100644
index 0000000000000..7ffc86c2299ec
--- /dev/null
+++ b/llvm/test/Transforms/AggressiveInstCombine/umulh.ll
@@ -0,0 +1,2571 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=aggressive-instcombine -S | FileCheck %s
+
+; https://alive2.llvm.org/ce/z/KuJPnU
+define i64 @umulh(i64 %x, i64 %y) {
+; CHECK-LABEL: define i64 @umulh(
+; CHECK-SAME: i64 [[X:%.*]], i64 [[Y:%.*]]) {
+; CHECK-NEXT: [[TMP0:%.*]] = zext i64 [[X]] to i128
+; CHECK-NEXT: [[TMP1:%.*]] = zext i64 [[Y]] to i128
+; CHECK-NEXT: [[TMP2:%.*]] = mul i128 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = lshr i128 [[TMP2]], 64
+; CHECK-NEXT: [[TMP4:%.*]] = trunc i128 [[TMP3]] to i64
+; CHECK-NEXT: ret i64 [[TMP4]]
+;
+ ; Extract low and high 32 bits
+ %x_lo = and i64 %x, 4294967295 ; x & 0xffffffff
+ %y_lo = and i64 %y, 4294967295 ; y & 0xffffffff
+ %x_hi = lshr i64 %x, 32 ; x >> 32
+ %y_hi = lshr i64 %y, 32 ; y >> 32
+
+ ; Cross products
+ %y_lo_x_hi = mul nuw i64 %y_lo, %x_hi ; y_lo * x_hi
+ %y_hi_x_hi = mul nuw i64 %y_hi, %x_hi ; y_hi * x_hi
+ %y_hi_x_lo = mul nuw i64 %y_hi, %x_lo ; y_hi * x_lo
+ %y_lo_x_lo = mul nuw i64 %y_lo, %x_lo ; y_lo * x_lo
+
+ ; Add cross terms
+ %cross_sum = add i64 %y_hi_x_lo, %y_lo_x_hi ; full 64-bit sum
+
+ ; Carry if overflowed
+ %carry_out = icmp ult i64 %cross_sum, %y_lo_x_hi
+ %carry = select i1 %carry_out, i64 4294967296, i64 0 ; if overflow, add 1 << 32
+
+ ; High 32 bits of low product
+ %y_lo_x_lo_hi = lshr i64 %y_lo_x_lo, 32
+
+ ; Low and high 32 bits of cross_sum
+ %cross_sum_lo = and i64 %cross_sum, 4294967295
+ %cross_sum_hi = lshr i64 %cross_sum, 32
+
+ %low_accum = add nuw nsw i64 %cross_sum_lo, %y_lo_x_lo_hi
+
+ ; Final result accumulation
+ %intermediate = add nuw i64 %cross_sum_hi, %y_hi_x_hi
+ %low_accum_hi = lshr i64 %low_accum, 32
+ %intermediate_plus_carry = add i64 %intermediate, %carry
+ %hw64 = add i64 %intermediate_plus_carry, %low_accum_hi
+
+ ret i64 %hw64
+}
+
+; https://alive2.llvm.org/ce/z/MSo5S_
+define i64 @umulh_variant(i64 %x, i64 %y) {
+; CHECK-LABEL: define i64 @umulh_variant(
+; CHECK-SAME: i64 [[X:%.*]], i64 [[Y:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = zext i64 [[X]] to i128
+; CHECK-NEXT: [[TMP2:%.*]] = zext i64 [[Y]] to i128
+; CHECK-NEXT: [[TMP3:%.*]] = mul i128 [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[TMP4:%.*]] = lshr i128 [[TMP3]], 64
+; CHECK-NEXT: [[TMP5:%.*]] = trunc i128 [[TMP4]] to i64
+; CHECK-NEXT: ret i64 [[TMP5]]
+;
+ %x_lo = and i64 %x, 4294967295
+ %y_lo = and i64 %y, 4294967295
+ %x_hi = lshr i64 %x, 32
+ %y_hi = lshr i64 %y, 32
+
+ %t0 = mul nuw i64 %y_lo, %x_lo
+ %t1 = mul nuw i64 %y_lo, %x_hi
+ %t2 = mul nuw i64 %y_hi, %x_lo
+ %t3 = mul nuw i64 %y_hi, %x_hi
+
+ %t0_hi = lshr i64 %t0, 32
+
+ %u0 = add nuw i64 %t0_hi, %t1
+ %u0_lo = and i64 %u0, 4294967295
+ %u0_hi = lshr i64 %u0, 32
+ %u1 = add nuw i64 %u0_lo, %t2
+ %u1_hi = lshr i64 %u1, 32
+ %u2 = add nuw i64 %u0_hi, %t3
+ %hw64 = add nuw i64 %u2, %u1_hi
+ ret i64 %hw64
+}
+
+; Commutative ops should match in any order. Ops where operand order has been
+; reversed from above are marked 'commuted'. As per instcombine contributors
+; guide, constants are always canonicalized to RHS, so don't both commuting
+; constants.
+define i64 @umulh__commuted(i64 %x, i64 %y) {
+; CHECK-LABEL: define i64 @umulh__commuted(
+; CHECK-SAME: i64 [[X:%.*]], i64 [[Y:%.*]]) {
+; CHECK-NEXT: [[TMP0:%.*]] = zext i64 [[X]] to i128
+; CHECK-NEXT: [[TMP1:%.*]] = zext i64 [[Y]] to i128
+; CHECK-NEXT: [[TMP2:%.*]] = mul i128 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = lshr i128 [[TMP2]], 64
+; CHECK-NEXT: [[TMP4:%.*]] = trunc i128 [[TMP3]] to i64
+; CHECK-NEXT: ret i64 [[TMP4]]
+;
+ ; Extract low and high 32 bits
+ %x_lo = and i64 %x, 4294967295
+ %y_lo = and i64 %y, 4294967295
+ %x_hi = lshr i64 %x, 32 ; x >> 32
+ %y_hi = lshr i64 %y, 32 ; y >> 32
+
+ ; Cross products
+ %y_lo_x_hi = mul nuw i64 %x_hi, %y_lo ; commuted
+ %y_hi_x_hi = mul nuw i64 %y_hi, %x_hi
+ %y_hi_x_lo = mul nuw i64 %x_lo, %y_hi ; commuted
+ %y_lo_x_lo = mul nuw i64 %x_lo, %y_lo ; commuted
+
+ ; Add cross terms
+ %cross_sum = add i64 %y_lo_x_hi, %y_hi_x_lo ; commuted
+
+ ; Carry if overflowed
+ %carry_out = icmp ult i64 %cross_sum, %y_lo_x_hi
+ %carry = select i1 %carry_out, i64 4294967296, i64 0 ; if overflow, add 1 << 32
+
+ ; High 32 bits of low product
+ %y_lo_x_lo_hi = lshr i64 %y_lo_x_lo, 32
+
+ ; Low and high 32 bits of cross_sum
+ %cross_sum_lo = and i64 4294967295, %cross_sum ; commuted
+ %cross_sum_hi = lshr i64 %cross_sum, 32
+
+ %low_accum = add nuw nsw i64 %y_lo_x_lo_hi, %cross_sum_lo ; commuted
+
+ ; Final result accumulation
+ %intermediate = add nuw i64 %y_hi_x_hi, %cross_sum_hi ; commuted
+ %low_accum_hi = lshr i64 %low_accum, 32
+ %intermediate_plus_carry = add i64 %carry, %intermediate ; commuted
+ %hw64 = add i64 %low_accum_hi, %intermediate_plus_carry ; commuted
+
+ ret i64 %hw64
+}
+
+define i64 @umulh_variant_commuted(i64 %x, i64 %y) {
+; CHECK-LABEL: define i64 @umulh_variant_commuted(
+; CHECK-SAME: i64 [[X:%.*]], i64 [[Y:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = zext i64 [[X]] to i128
+; CHECK-NEXT: [[TMP2:%.*]] = zext i64 [[Y]] to i128
+; CHECK-NEXT: [[TMP3:%.*]] = mul i128 [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[TMP4:%.*]] = lshr i128 [[TMP3]], 64
+; CHECK-NEXT: [[TMP5:%.*]] = trunc i128 [[TMP4]] to i64
+; CHECK-NEXT: ret i64 [[TMP5]]
+;
+ %x_lo = and i64 %x, 4294967295
+ %y_lo = and i64 %y, 4294967295
+ %x_hi = lshr i64 %x, 32
+ %y_hi = lshr i64 %y, 32
+
+ %t0 = mul nuw i64 %x_lo, %y_lo ; commuted
+ %t1 = mul nuw i64 %x_hi, %y_lo ; commuted
+ %t2 = mul nuw i64 %x_lo, %y_hi ; commuted
+ %t3 = mul nuw i64 %y_hi, %x_hi
+
+ %t0_hi = lshr i64 %t0, 32
+
+ %u0 = add nuw i64 %t1, %t0_hi ; commuted
+ %u0_lo = and i64 4294967295, %u0 ; commuted
+ %u0_hi = lshr i64 %u0, 32
+ %u1 = add nuw i64 %t2, %u0_lo ; commuted
+ %u1_hi = lshr i64 %u1, 32
+ %u2 = add nuw i64 %t3, %u0_hi ; commuted
+ %hw64 = add nuw i64 %u1_hi, %u2 ; commuted
+ ret i64 %hw64
+}
+
+; https://alive2.llvm.org/ce/z/PPXtkR
+define void @full_mul_int128(i64 %x, i64 %y, ptr %p) {
+; CHECK-LABEL: define void @full_mul_int128(
+; CHECK-SAME: i64 [[X:%.*]], i64 [[Y:%.*]], ptr [[P:%.*]]) {
+; CHECK-NEXT: [[TMP0:%.*]] = zext i64 [[X]] to i128
+; CHECK-NEXT: [[TMP1:%.*]] = zext i64 [[Y]] to i128
+; CHECK-NEXT: [[TMP2:%.*]] = mul i128 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = lshr i128 [[TMP2]], 64
+; CHECK-NEXT: [[TMP4:%.*]] = trunc i128 [[TMP3]] to i64
+; CHECK-NEXT: [[HI_PTR:%.*]] = getelementptr inbounds i8, ptr [[P]], i64 8
+; CHECK-NEXT: store i64 [[TMP4]], ptr [[HI_PTR]], align 8
+; CHECK-NEXT: [[TMP5:%.*]] = zext i64 [[X]] to i128
+; CHECK-NEXT: [[TMP6:%.*]] = zext i64 [[Y]] to i128
+; CHECK-NEXT: [[TMP7:%.*]] = mul i128 [[TMP5]], [[TMP6]]
+; CHECK-NEXT: [[TMP8:%.*]] = trunc i128 [[TMP7]] to i64
+; CHECK-NEXT: store i64 [[TMP8]], ptr [[P]], align 8
+; CHECK-NEXT: ret void
+;
+ ; Extract low and high 32 bits
+ %x_lo = and i64 %x, 4294967295 ; x & 0xffffffff
+ %y_lo = and i64 %y, 4294967295 ; y & 0xffffffff
+ %x_hi = lshr i64 %x, 32 ; x >> 32
+ %y_hi = lshr i64 %y, 32 ; y >> 32
+
+ ; Cross products
+ %y_lo_x_hi = mul nuw i64 %y_lo, %x_hi ; y_lo * x_hi
+ %y_hi_x_hi = mul nuw i64 %y_hi, %x_hi ; y_hi * x_hi
+ %y_hi_x_lo = mul nuw i64 %y_hi, %x_lo ; y_hi * x_lo
+ %y_lo_x_lo = mul nuw i64 %y_lo, %x_lo ; y_lo * x_lo
+
+ ; Add cross terms
+ %cross_sum = add i64 %y_hi_x_lo, %y_lo_x_hi ; full 64-bit sum
+
+ ; Carry if overflowed
+ %carry_out = icmp ult i64 %cross_sum, %y_lo_x_hi
+ %carry = select i1 %carry_out, i64 4294967296, i64 0 ; if overflow, add 1 << 32
+
+ ; High 32 bits of low product
+ %y_lo_x_lo_hi = lshr i64 %y_lo_x_lo, 32
+
+ ; Low and high 32 bits of cross_sum
+ %cross_sum_lo = and i64 %cross_sum, 4294967295
+ %cross_sum_hi = lshr i64 %cross_sum, 32
+
+ %low_accum = add nuw nsw i64 %cross_sum_lo, %y_lo_x_lo_hi
+
+ ; Final result accumulation
+ %upper_mid = add nuw i64 %y_hi_x_hi, %carry
+ %low_accum_hi = lshr i64 %low_accum, 32
+ %upper_mid_with_cross = add i64 %upper_mid, %cross_sum_hi
+ %hw64 = add i64 %upper_mid_with_cross, %low_accum_hi
+
+ ; Store high 64 bits
+ %hi_ptr = getelementptr inbounds i8, ptr %p, i64 8
+ store i64 %hw64, ptr %hi_ptr, align 8
+
+ ; Reconstruct low 64 bits
+ %low_accum_shifted = shl i64 %low_accum, 32
+ %y_lo_x_lo_lo = and i64 %y_lo_x_lo, 4294967295
+ %lw64 = or disjoint i64 %low_accum_shifted, %y_lo_x_lo_lo
+
+ ; Store low 64 bits
+ store i64 %lw64, ptr %p, align 8
+
+ ret void
+}
+
+; Negative tests
+
+; 'x_lo' must have exactly 2 uses.
+define i64 @umulh__mul_use__x_lo(i64 %x, i64 %y) {
+; CHECK-LABEL: define i64 @umulh__mul_use__x_lo(
+; CHECK-NOT: i128
+ ; Extract low and high 32 bits
+ %x_lo = and i64 %x, 4294967295 ; x & 0xffffffff
+ call void (...) @llvm.fake.use(i64 %x_lo)
+ %y_lo = and i64 %y, 4294967295 ; y & 0xffffffff
+ %x_hi = lshr i64 %x, 32 ; x >> 32
+ %y_hi = lshr i64 %y, 32 ; y >> 32
+
+ ; Cross products
+ %y_lo_x_hi = mul nuw i64 %y_lo, %x_hi ; y_lo * x_hi
+ %y_hi_x_hi = mul nuw i64 %y_hi, %x_hi ; y_hi * x_hi
+ %y_hi_x_lo = mul nuw i64 %y_hi, %x_lo ; y_hi * x_lo
+ %y_lo_x_lo = mul nuw i64 %y_lo, %x_lo ; y_lo * x_lo
+
+ ; Add cross terms
+ %cross_sum = add i64 %y_hi_x_lo, %y_lo_x_hi ; full 64...
[truncated]
|
|
#60200 is related, but this is the other pattern I'm aware of that I mention in the summary and isn't handled. I discovered this issue after I started working on this. |
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
|
ping - any thoughts on this? |
llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
Outdated
Show resolved
Hide resolved
| MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT); | ||
| MadeChange |= foldPatternedLoads(I, DL); | ||
| MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT); | ||
| MadeChange |= foldMul128Low(I, DL, DT); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe gate on i128 legality? Otherwise, It may lead to regressions on targets without native i128 mul.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the target doesn't need native 128-bit, AArch64 for example doesnt have a single instruction to do a scalar 128-bit mul, but it can be done in 2 x i64 parts.
cbadf5d to
5bce6f8
Compare
usha1830
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks!
This patch adds patterns to recognize a full i64 x i64 -> i128 multiplication by 4 x i32 parts, folding it to a full 128-bit multiply. The low/high parts are implemented as independent patterns. There's also an additional pattern for the high part, both patterns have been seen in real code, and there's one more I'm aware of but I thought I'd share a patch first to see what people think before handling any further cases. On AArch64 the mul and umulh instructions can be used to efficiently compute the low/high parts. I also believe X86 can do the i128 mul in one instruction (returning both halves). So it seems like this is relatively common and could be a useful optimization for several targets.
5bce6f8 to
34672da
Compare
davemgreen
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi. Could we try splitting one of these into a separate patch, perhaps one of the high variants, to split this up a little?
It looks like having a phase-ordering test would be useful too, to protect against codegen changing in the future.
| static bool foldMul128Low(Instruction &I, const DataLayout &DL, | ||
| DominatorTree &DT) { | ||
| auto *Ty = I.getType(); | ||
| if (!Ty->isIntegerTy(64)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What would happen if this was generalized to more types? Would it trigger a lot more, and would those be profitable for smaller types? It would help make this more generic.
| /// | ||
| /// Use counts are checked to prevent total instruction count increase as per | ||
| /// contributors guide: | ||
| /// https://llvm.org/docs/InstCombineContributorGuide.html#multi-use-handling |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need links to the ContributorGuide here. It might be useful to explain what it is matching instead.
| Value *XExt = Builder.CreateZExt(X, Builder.getInt128Ty()); | ||
| Value *YExt = Builder.CreateZExt(Y, Builder.getInt128Ty()); | ||
| Value *Mul128 = Builder.CreateMul(XExt, YExt); | ||
| Value *Res = Builder.CreateTrunc(Mul128, Builder.getInt64Ty()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this just be a mul?
| ; CHECK-NOT: i128 | ||
| ; Extract low and high 32 bits | ||
| %x_lo = and i64 %x, 4294967295 ; x & 0xffffffff | ||
| call void (...) @llvm.fake.use(i64 %x_lo) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It may be fine to transform if the uses are so close to the final operands - it can remove more code overall to do the transform.
|
|
||
| // (hi(y) * lo(x)) + (lo(y) * hi(x)) | ||
| Value *YHigh = nullptr, *XLow = nullptr, *YLowXHigh = nullptr; | ||
| if (!match(CrossSum, m_c_Add(m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XLow))), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
m_c_Mul can be m_Mul as both the operands are m_Value.
| Value *CrossSum = nullptr; | ||
| if (!match( | ||
| LowAccum, | ||
| m_c_Add(m_OneUse(m_And(m_Value(CrossSum), m_SpecificInt(0xffffffff))), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this And might be optional, as it only clears top bits that we will ignore in the shift.
| Value *CrossSum = nullptr, *XHigh = nullptr, *YHigh = nullptr, | ||
| *Carry = nullptr; | ||
| if (!match(IntermediatePlusCarry, | ||
| m_c_Add(m_c_Add(m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XHigh))), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
m_c_Mul -> m_Mul
This patch adds recognition of high-half multiply by parts into a single larger
multiply.
Considering a multiply made up of high and low parts, we can split the
multiply into:
x * y == (xh*T + xl) * (yh*T + yl)
where xh == x>>32 and xl == x & 0xffffffff. T = 2^32.
This expands to
xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl
which I find it helpful to be drawn as
[ xh*yh ]
[ xh*yl ]
[ xl*yh ]
[ xl*yl ]
We are looking for the "high" half, which is xh*yh + xh*yl>>32 + xl*yh>>32 +
carrys. The carry makes this difficult and there are multiple ways of
representing it. The ones we attempt to support here are:
Carry: xh*yh + carry + lowsum
carry = lowsum < xh*yl ? 0x1000000 : 0
lowsum = xh*yl + xl*yh + (xl*yl>>32)
Ladder: xh*yh + c2>>32 + c3>>32
c2 = xh*yl + (xl*yl >> 32); c3 = c2&0xffffffff + xl*yh
Carry4: xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32
crosssum = xh*yl + xl*yh
carry = crosssum < xh*yl ? 0x1000000 : 0
Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
They all start by matching xh*yh + 2 or 3 other operands. The bottom of the
tree is xh*yh, xh*yl, xl*yh and xl*yl.
Based on llvm#156879 by @c-rhodes
This patch adds recognition of high-half multiply by parts into a single larger
multiply.
Considering a multiply made up of high and low parts, we can split the
multiply into:
x * y == (xh*T + xl) * (yh*T + yl)
where xh == x>>32 and xl == x & 0xffffffff. T = 2^32.
This expands to
xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl
which I find it helpful to be drawn as
[ xh*yh ]
[ xh*yl ]
[ xl*yh ]
[ xl*yl ]
We are looking for the "high" half, which is xh*yh + xh*yl>>32 + xl*yh>>32 +
carrys. The carry makes this difficult and there are multiple ways of
representing it. The ones we attempt to support here are:
Carry: xh*yh + carry + lowsum
carry = lowsum < xh*yl ? 0x1000000 : 0
lowsum = xh*yl + xl*yh + (xl*yl>>32)
Ladder: xh*yh + c2>>32 + c3>>32
c2 = xh*yl + (xl*yl >> 32); c3 = c2&0xffffffff + xl*yh
Carry4: xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32
crosssum = xh*yl + xl*yh
carry = crosssum < xh*yl ? 0x1000000 : 0
Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
They all start by matching xh*yh + 2 or 3 other operands. The bottom of the
tree is xh*yh, xh*yl, xl*yh and xl*yl.
Based on llvm#156879 by @c-rhodes
This patch adds patterns to recognize a full i64 x i64 -> i128 multiplication by 4 x i32 parts, folding it to a full 128-bit multiply. The low/high parts are implemented as independent patterns. There's also an additional pattern for the high part, both patterns have been seen in real code, and there's one more I'm aware of but I thought I'd share a patch first to see what people think before handling any further cases.
On AArch64 the mul and umulh instructions can be used to efficiently compute the low/high parts. I also believe X86 can do the i128 mul in one instruction (returning both halves). So it seems like this is relatively common and could be a useful optimization for several targets.
high variant 1 - https://alive2.llvm.org/ce/z/KuJPnU
high variant 2 - https://alive2.llvm.org/ce/z/MSo5S_
low & high - https://alive2.llvm.org/ce/z/PPXtkR