-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[AggressiveInstCombine] Fold i64 x i64 -> i128 multiply-by-parts #156879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1457,6 +1457,268 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI, | |
| return false; | ||
| } | ||
|
|
||
| /// Match low part of 128-bit multiplication. | ||
| /// | ||
| /// Use counts are checked to prevent total instruction count increase as per | ||
| /// contributors guide: | ||
| /// https://llvm.org/docs/InstCombineContributorGuide.html#multi-use-handling | ||
| static bool foldMul128Low(Instruction &I) { | ||
| auto *Ty = I.getType(); | ||
| if (!Ty->isIntegerTy(64)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What would happen if this was generalized to more types? Would it trigger a lot more, and would those be profitable for smaller types? It would help make this more generic. |
||
| return false; | ||
|
|
||
| // (low_accum << 32) | lo(lo(y) * lo(x)) | ||
| Value *LowAccum = nullptr, *YLowXLow = nullptr; | ||
| if (!match(&I, m_c_DisjointOr( | ||
| m_OneUse(m_Shl(m_Value(LowAccum), m_SpecificInt(32))), | ||
| m_OneUse( | ||
| m_And(m_Value(YLowXLow), m_SpecificInt(0xffffffff)))))) | ||
| return false; | ||
|
|
||
| // lo(cross_sum) + hi(lo(y) * lo(x)) | ||
| Value *CrossSum = nullptr; | ||
| if (!match( | ||
| LowAccum, | ||
| m_c_Add(m_OneUse(m_And(m_Value(CrossSum), m_SpecificInt(0xffffffff))), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this And might be optional, as it only clears top bits that we will ignore in the shift. |
||
| m_OneUse(m_LShr(m_Specific(YLowXLow), m_SpecificInt(32))))) || | ||
| LowAccum->hasNUsesOrMore(3)) | ||
c-rhodes marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return false; | ||
|
|
||
| // (hi(y) * lo(x)) + (lo(y) * hi(x)) | ||
| Value *YHigh = nullptr, *XLow = nullptr, *YLowXHigh = nullptr; | ||
| if (!match(CrossSum, m_c_Add(m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XLow))), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. m_c_Mul can be m_Mul as both the operands are m_Value. |
||
| m_Value(YLowXHigh))) || | ||
| CrossSum->hasNUsesOrMore(4)) | ||
| return false; | ||
|
|
||
| // lo(y) * lo(x) | ||
| Value *YLow = nullptr; | ||
| if (!match(YLowXLow, m_c_Mul(m_Value(YLow), m_Specific(XLow))) || | ||
| YLowXLow->hasNUsesOrMore(3)) | ||
| return false; | ||
|
|
||
| // lo(y) * hi(x) | ||
| Value *XHigh = nullptr; | ||
| if (!match(YLowXHigh, m_c_Mul(m_Specific(YLow), m_Value(XHigh))) || | ||
| !YLowXHigh->hasNUses(2)) | ||
| return false; | ||
|
|
||
| Value *X = nullptr; | ||
| // lo(x) = x & 0xffffffff | ||
| if (!match(XLow, m_c_And(m_Value(X), m_SpecificInt(0xffffffff))) || | ||
| !XLow->hasNUses(2)) | ||
| return false; | ||
| // hi(x) = x >> 32 | ||
| if (!match(XHigh, m_LShr(m_Specific(X), m_SpecificInt(32))) || | ||
| !XHigh->hasNUses(2)) | ||
| return false; | ||
|
|
||
| // Same for Y. | ||
| Value *Y = nullptr; | ||
| if (!match(YLow, m_c_And(m_Value(Y), m_SpecificInt(0xffffffff))) || | ||
| !YLow->hasNUses(2)) | ||
| return false; | ||
| if (!match(YHigh, m_LShr(m_Specific(Y), m_SpecificInt(32))) || | ||
| !YHigh->hasNUses(2)) | ||
| return false; | ||
|
|
||
| IRBuilder<> Builder(&I); | ||
| Value *XExt = Builder.CreateZExt(X, Builder.getInt128Ty()); | ||
| Value *YExt = Builder.CreateZExt(Y, Builder.getInt128Ty()); | ||
| Value *Mul128 = Builder.CreateMul(XExt, YExt); | ||
| Value *Res = Builder.CreateTrunc(Mul128, Builder.getInt64Ty()); | ||
|
Comment on lines
+1526
to
+1529
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this just be a mul? |
||
| I.replaceAllUsesWith(Res); | ||
|
|
||
| return true; | ||
| } | ||
|
|
||
| /// Match high part of 128-bit multiplication. | ||
| /// | ||
| /// Use counts are checked to prevent total instruction count increase as per | ||
| /// contributors guide: | ||
| /// https://llvm.org/docs/InstCombineContributorGuide.html#multi-use-handling | ||
| static bool foldMul128High(Instruction &I) { | ||
| auto *Ty = I.getType(); | ||
| if (!Ty->isIntegerTy(64)) | ||
| return false; | ||
|
|
||
| // intermediate_plus_carry + hi(low_accum) | ||
| Value *IntermediatePlusCarry = nullptr, *LowAccum = nullptr; | ||
| if (!match(&I, | ||
| m_c_Add(m_OneUse(m_Value(IntermediatePlusCarry)), | ||
| m_OneUse(m_LShr(m_Value(LowAccum), m_SpecificInt(32)))))) | ||
| return false; | ||
|
|
||
| // match: | ||
| // (((hi(y) * hi(x)) + carry) + hi(cross_sum)) | ||
| // or: | ||
| // ((hi(cross_sum) + (hi(y) * hi(x))) + carry) | ||
| CmpPredicate Pred; | ||
| Value *CrossSum = nullptr, *XHigh = nullptr, *YHigh = nullptr, | ||
| *Carry = nullptr; | ||
| if (!match(IntermediatePlusCarry, | ||
| m_c_Add(m_c_Add(m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XHigh))), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. m_c_Mul -> m_Mul |
||
| m_Value(Carry)), | ||
| m_OneUse(m_LShr(m_Value(CrossSum), m_SpecificInt(32))))) && | ||
| !match(IntermediatePlusCarry, | ||
| m_c_Add(m_OneUse(m_c_Add( | ||
| m_OneUse(m_LShr(m_Value(CrossSum), m_SpecificInt(32))), | ||
| m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XHigh))))), | ||
| m_Value(Carry)))) | ||
| return false; | ||
|
|
||
| // (select (icmp ult cross_sum, (lo(y) * hi(x))), (1 << 32), 0) | ||
| Value *YLowXHigh = nullptr; | ||
| if (!match(Carry, | ||
| m_OneUse(m_Select(m_OneUse(m_ICmp(Pred, m_Specific(CrossSum), | ||
| m_Value(YLowXHigh))), | ||
| m_SpecificInt(4294967296), m_SpecificInt(0)))) || | ||
| Pred != ICmpInst::ICMP_ULT) | ||
| return false; | ||
|
|
||
| // (hi(y) * lo(x)) + (lo(y) * hi(x)) | ||
| Value *XLow = nullptr; | ||
| if (!match(CrossSum, | ||
| m_c_Add(m_OneUse(m_c_Mul(m_Specific(YHigh), m_Value(XLow))), | ||
| m_Specific(YLowXHigh))) || | ||
| CrossSum->hasNUsesOrMore(4)) | ||
| return false; | ||
|
|
||
| // lo(y) * hi(x) | ||
| Value *YLow = nullptr; | ||
| if (!match(YLowXHigh, m_c_Mul(m_Value(YLow), m_Specific(XHigh))) || | ||
| !YLowXHigh->hasNUses(2)) | ||
| return false; | ||
|
|
||
| // lo(cross_sum) + hi(lo(y) * lo(x)) | ||
| Value *YLowXLow = nullptr; | ||
| if (!match(LowAccum, | ||
| m_c_Add(m_OneUse(m_c_And(m_Specific(CrossSum), | ||
| m_SpecificInt(0xffffffff))), | ||
| m_OneUse(m_LShr(m_Value(YLowXLow), m_SpecificInt(32))))) || | ||
| LowAccum->hasNUsesOrMore(3)) | ||
| return false; | ||
|
|
||
| // lo(y) * lo(x) | ||
| // | ||
| // When only doing the high part there's a single use and 2 uses when doing | ||
| // full multiply. Given the low/high patterns are separate, it's non-trivial | ||
| // to vary the number of uses to check this, but applying the optimization | ||
| // when there's an unrelated use when only doing the high part still results | ||
| // in less instructions and is likely profitable, so an upper bound of 2 uses | ||
| // should be fine. | ||
| if (!match(YLowXLow, m_c_Mul(m_Specific(YLow), m_Specific(XLow))) || | ||
| YLowXLow->hasNUsesOrMore(3)) | ||
| return false; | ||
|
|
||
| Value *X = nullptr; | ||
| // lo(x) = x & 0xffffffff | ||
| if (!match(XLow, m_c_And(m_Value(X), m_SpecificInt(0xffffffff))) || | ||
| !XLow->hasNUses(2)) | ||
| return false; | ||
| // hi(x) = x >> 32 | ||
| if (!match(XHigh, m_LShr(m_Specific(X), m_SpecificInt(32))) || | ||
| !XHigh->hasNUses(2)) | ||
| return false; | ||
|
|
||
| // Same for Y. | ||
| Value *Y = nullptr; | ||
| if (!match(YLow, m_c_And(m_Value(Y), m_SpecificInt(0xffffffff))) || | ||
| !YLow->hasNUses(2)) | ||
| return false; | ||
| if (!match(YHigh, m_LShr(m_Specific(Y), m_SpecificInt(32))) || | ||
| !YHigh->hasNUses(2)) | ||
| return false; | ||
|
|
||
| IRBuilder<> Builder(&I); | ||
| Value *XExt = Builder.CreateZExt(X, Builder.getInt128Ty()); | ||
| Value *YExt = Builder.CreateZExt(Y, Builder.getInt128Ty()); | ||
| Value *Mul128 = Builder.CreateMul(XExt, YExt); | ||
| Value *High = Builder.CreateLShr(Mul128, 64); | ||
| Value *Res = Builder.CreateTrunc(High, Builder.getInt64Ty()); | ||
| I.replaceAllUsesWith(Res); | ||
|
|
||
| return true; | ||
| } | ||
|
|
||
| /// Match another variant of high part of 128-bit multiplication. | ||
| /// | ||
| /// %t0 = mul nuw i64 %y_lo, %x_lo | ||
| /// %t1 = mul nuw i64 %y_lo, %x_hi | ||
| /// %t2 = mul nuw i64 %y_hi, %x_lo | ||
| /// %t3 = mul nuw i64 %y_hi, %x_hi | ||
| /// %t0_hi = lshr i64 %t0, 32 | ||
| /// %u0 = add nuw i64 %t0_hi, %t1 | ||
| /// %u0_lo = and i64 %u0, 4294967295 | ||
| /// %u0_hi = lshr i64 %u0, 32 | ||
| /// %u1 = add nuw i64 %u0_lo, %t2 | ||
| /// %u1_hi = lshr i64 %u1, 32 | ||
| /// %u2 = add nuw i64 %u0_hi, %t3 | ||
| /// %hw64 = add nuw i64 %u2, %u1_hi | ||
| /// | ||
| /// Use counts are checked to prevent total instruction count increase as per | ||
| /// contributors guide: | ||
| /// https://llvm.org/docs/InstCombineContributorGuide.html#multi-use-handling | ||
| static bool foldMul128HighVariant(Instruction &I) { | ||
| auto *Ty = I.getType(); | ||
| if (!Ty->isIntegerTy(64)) | ||
| return false; | ||
|
|
||
| // hw64 = (hi(u0) + (hi(y) * hi(x)) + (lo(u0) + (hi(y) * lo(x)) >> 32)) | ||
| Value *U0 = nullptr, *XHigh = nullptr, *YHigh = nullptr, *XLow = nullptr; | ||
| if (!match( | ||
| &I, | ||
| m_c_Add(m_OneUse(m_c_Add( | ||
| m_OneUse(m_LShr(m_Value(U0), m_SpecificInt(32))), | ||
| m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XHigh))))), | ||
| m_OneUse(m_LShr( | ||
| m_OneUse(m_c_Add( | ||
| m_OneUse(m_c_And(m_Deferred(U0), | ||
| m_SpecificInt(0xffffffff))), | ||
| m_OneUse(m_c_Mul(m_Deferred(YHigh), m_Value(XLow))))), | ||
| m_SpecificInt(32)))))) | ||
| return false; | ||
|
|
||
| // u0 = (hi(lo(y) * lo(x)) + (lo(y) * hi(x))) | ||
| Value *YLow = nullptr; | ||
| if (!match(U0, | ||
| m_c_Add(m_OneUse(m_LShr( | ||
| m_OneUse(m_c_Mul(m_Value(YLow), m_Specific(XLow))), | ||
| m_SpecificInt(32))), | ||
| m_OneUse(m_c_Mul(m_Deferred(YLow), m_Specific(XHigh))))) || | ||
| !U0->hasNUses(2)) | ||
| return false; | ||
|
|
||
| Value *X = nullptr; | ||
| // lo(x) = x & 0xffffffff | ||
| if (!match(XLow, m_c_And(m_Value(X), m_SpecificInt(0xffffffff))) || | ||
| !XLow->hasNUses(2)) | ||
| return false; | ||
| // hi(x) = x >> 32 | ||
| if (!match(XHigh, m_LShr(m_Specific(X), m_SpecificInt(32))) || | ||
| !XHigh->hasNUses(2)) | ||
| return false; | ||
|
|
||
| // Same for Y. | ||
| Value *Y = nullptr; | ||
| if (!match(YLow, m_c_And(m_Value(Y), m_SpecificInt(0xffffffff))) || | ||
| !YLow->hasNUses(2)) | ||
| return false; | ||
| if (!match(YHigh, m_LShr(m_Specific(Y), m_SpecificInt(32))) || | ||
| !YHigh->hasNUses(2)) | ||
| return false; | ||
|
|
||
| IRBuilder<> Builder(&I); | ||
| Value *XExt = Builder.CreateZExt(X, Builder.getInt128Ty()); | ||
| Value *YExt = Builder.CreateZExt(Y, Builder.getInt128Ty()); | ||
| Value *Mul128 = Builder.CreateMul(XExt, YExt); | ||
| Value *High = Builder.CreateLShr(Mul128, 64); | ||
| Value *Res = Builder.CreateTrunc(High, Builder.getInt64Ty()); | ||
| I.replaceAllUsesWith(Res); | ||
|
|
||
| return true; | ||
| } | ||
|
|
||
| /// This is the entry point for folds that could be implemented in regular | ||
| /// InstCombine, but they are separated because they are not expected to | ||
| /// occur frequently and/or have more than a constant-length pattern match. | ||
|
|
@@ -1486,6 +1748,9 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT, | |
| MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT); | ||
| MadeChange |= foldPatternedLoads(I, DL); | ||
| MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT); | ||
| MadeChange |= foldMul128Low(I); | ||
| MadeChange |= foldMul128High(I); | ||
| MadeChange |= foldMul128HighVariant(I); | ||
| // NOTE: This function introduces erasing of the instruction `I`, so it | ||
| // needs to be called at the end of this sequence, otherwise we may make | ||
| // bugs. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need links to the ContributorGuide here. It might be useful to explain what it is matching instead.