-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[AggressiveInstCombine] Match long high-half multiply #168396
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -1466,6 +1466,319 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI, | |||||
| return false; | ||||||
| } | ||||||
|
|
||||||
| /// Match high part of long multiplication. | ||||||
| /// | ||||||
| /// Considering a multiply made up of high and low parts, we can split the | ||||||
| /// multiply into: | ||||||
| /// x * y == (xh*T + xl) * (yh*T + yl) | ||||||
| /// where xh == x>>32 and xl == x & 0xffffffff. T = 2^32. | ||||||
| /// This expands to | ||||||
| /// xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl | ||||||
| /// which can be drawn as | ||||||
| /// [ xh*yh ] | ||||||
| /// [ xh*yl ] | ||||||
| /// [ xl*yh ] | ||||||
| /// [ xl*yl ] | ||||||
| /// We are looking for the "high" half, which is xh*yh + xh*yl>>32 + xl*yh>>32 + | ||||||
| /// some carrys. The carry makes this difficult and there are multiple ways of | ||||||
| /// representing it. The ones we attempt to support here are: | ||||||
| /// Carry: xh*yh + carry + lowsum | ||||||
| /// carry = lowsum < xh*yl ? 0x1000000 : 0 | ||||||
| /// lowsum = xh*yl + xl*yh + (xl*yl>>32) | ||||||
| /// Ladder: xh*yh + c2>>32 + c3>>32 | ||||||
| /// c2 = xh*yl + (xl*yl>>32); c3 = c2&0xffffffff + xl*yh | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here is another variant that is a bit different from this pattern: https://github.com/Cyan4973/xxHash/blob/136cc1f8fe4d5ea62a7c16c8424d4fa5158f6d68/xxhash.h#L4568-L4582
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's an interesting one - the graph looks like a simpler version of what I have called ladder. It has less cross-edges, but I've incorporated it into the logic of FoldMulHighLadder. It would be nice if some of these canonicalized together but they are different enough that it seems difficult without matching the whole tree again.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. looks like the same one as #60200 |
||||||
| /// or c2 = (xl*yh&0xffffffff) + xh*yl + (xl*yl>>32); c3 = xl*yh | ||||||
| /// Carry4: xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32 | ||||||
| /// crosssum = xh*yl + xl*yh | ||||||
| /// carry = crosssum < xh*yl ? 0x1000000 : 0 | ||||||
| /// Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32; | ||||||
| /// low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff | ||||||
| /// | ||||||
| /// They all start by matching xh*yh + 2 or 3 other operands. The bottom of the | ||||||
| /// tree is xh*yh, xh*yl, xl*yh and xl*yl. | ||||||
| static bool foldMulHigh(Instruction &I) { | ||||||
| Type *Ty = I.getType(); | ||||||
| if (!Ty->isIntOrIntVectorTy()) | ||||||
| return false; | ||||||
|
|
||||||
| unsigned BW = Ty->getScalarSizeInBits(); | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: please spell it out
Suggested change
|
||||||
| APInt LowMask = APInt::getLowBitsSet(BW, BW / 2); | ||||||
| if (BW % 2 != 0) | ||||||
| return false; | ||||||
|
|
||||||
| auto CreateMulHigh = [&](Value *X, Value *Y) { | ||||||
| IRBuilder<> Builder(&I); | ||||||
| Type *NTy = Ty->getWithNewBitWidth(BW * 2); | ||||||
| Value *XExt = Builder.CreateZExt(X, NTy); | ||||||
| Value *YExt = Builder.CreateZExt(Y, NTy); | ||||||
| Value *Mul = Builder.CreateMul(XExt, YExt, "", true); | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
same below |
||||||
| Value *High = Builder.CreateLShr(Mul, BW); | ||||||
| Value *Res = Builder.CreateTrunc(High, Ty, "", true); | ||||||
| Res->takeName(&I); | ||||||
| I.replaceAllUsesWith(Res); | ||||||
davemgreen marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
| LLVM_DEBUG(dbgs() << "Created long multiply from parts of " << *X << " and " | ||||||
| << *Y << "\n"); | ||||||
| return true; | ||||||
| }; | ||||||
|
|
||||||
| // Common check routines for X_lo*Y_lo and X_hi*Y_lo | ||||||
| auto CheckLoLo = [&](Value *XlYl, Value *X, Value *Y) { | ||||||
| return match(XlYl, m_c_Mul(m_And(m_Specific(X), m_SpecificInt(LowMask)), | ||||||
| m_And(m_Specific(Y), m_SpecificInt(LowMask)))); | ||||||
| }; | ||||||
| auto CheckHiLo = [&](Value *XhYl, Value *X, Value *Y) { | ||||||
| return match(XhYl, m_c_Mul(m_LShr(m_Specific(X), m_SpecificInt(BW / 2)), | ||||||
| m_And(m_Specific(Y), m_SpecificInt(LowMask)))); | ||||||
| }; | ||||||
|
|
||||||
| auto FoldMulHighCarry = [&](Value *X, Value *Y, Instruction *Carry, | ||||||
| Instruction *B) { | ||||||
| // Looking for LowSum >> 32 and carry (select) | ||||||
| if (Carry->getOpcode() != Instruction::Select) | ||||||
| std::swap(Carry, B); | ||||||
|
|
||||||
| // Carry = LowSum < XhYl ? 0x100000000 : 0 | ||||||
| Value *LowSum, *XhYl; | ||||||
| if (!match(Carry, | ||||||
| m_OneUse(m_Select( | ||||||
| m_OneUse(m_SpecificICmp(ICmpInst::ICMP_ULT, m_Value(LowSum), | ||||||
| m_Value(XhYl))), | ||||||
| m_SpecificInt(APInt::getOneBitSet(BW, BW / 2)), m_Zero())))) | ||||||
| return false; | ||||||
|
|
||||||
| // XhYl can be Xh*Yl or Xl*Yh | ||||||
| if (!CheckHiLo(XhYl, X, Y)) { | ||||||
| if (CheckHiLo(XhYl, Y, X)) | ||||||
| std::swap(X, Y); | ||||||
| else | ||||||
| return false; | ||||||
| } | ||||||
| if (XhYl->hasNUsesOrMore(3)) | ||||||
| return false; | ||||||
|
|
||||||
| // B = LowSum >> 32 | ||||||
| if (!match(B, | ||||||
| m_OneUse(m_LShr(m_Specific(LowSum), m_SpecificInt(BW / 2)))) || | ||||||
| LowSum->hasNUsesOrMore(3)) | ||||||
| return false; | ||||||
|
|
||||||
| // LowSum = XhYl + XlYh + XlYl>>32 | ||||||
| Value *XlYh, *XlYl; | ||||||
| auto XlYlHi = m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2)); | ||||||
| if (!match(LowSum, | ||||||
| m_c_Add(m_Specific(XhYl), | ||||||
| m_OneUse(m_c_Add(m_OneUse(m_Value(XlYh)), XlYlHi)))) && | ||||||
| !match(LowSum, m_c_Add(m_OneUse(m_Value(XlYh)), | ||||||
| m_OneUse(m_c_Add(m_Specific(XhYl), XlYlHi)))) && | ||||||
| !match(LowSum, | ||||||
| m_c_Add(XlYlHi, m_OneUse(m_c_Add(m_Specific(XhYl), | ||||||
| m_OneUse(m_Value(XlYh))))))) | ||||||
| return false; | ||||||
|
|
||||||
| // Check XlYl and XlYh | ||||||
| if (!CheckLoLo(XlYl, X, Y)) | ||||||
| return false; | ||||||
| if (!CheckHiLo(XlYh, Y, X)) | ||||||
| return false; | ||||||
|
|
||||||
| return CreateMulHigh(X, Y); | ||||||
| }; | ||||||
|
|
||||||
| auto FoldMulHighLadder = [&](Value *X, Value *Y, Instruction *A, | ||||||
| Instruction *B) { | ||||||
| // xh*yh + c2>>32 + c3>>32 | ||||||
| // c2 = xh*yl + (xl*yl>>32); c3 = c2&0xffffffff + xl*yh | ||||||
| // or c2 = (xl*yh&0xffffffff) + xh*yl + (xl*yl>>32); c3 = xh*yl | ||||||
| Value *XlYh, *XhYl, *XlYl, *C2, *C3; | ||||||
| // Strip off the two expected shifts. | ||||||
| if (!match(A, m_LShr(m_Value(C2), m_SpecificInt(BW / 2))) || | ||||||
| !match(B, m_LShr(m_Value(C3), m_SpecificInt(BW / 2)))) | ||||||
| return false; | ||||||
|
|
||||||
| if (match(C3, m_c_Add(m_Add(m_Value(), m_Value()), m_Value()))) | ||||||
| std::swap(C2, C3); | ||||||
| // Try to match c2 = (xl*yh&0xffffffff) + xh*yl + (xl*yl>>32) | ||||||
| if (match(C2, m_c_Add(m_c_Add(m_And(m_Specific(C3), m_SpecificInt(LowMask)), | ||||||
| m_Value(XlYh)), | ||||||
| m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2)))) || | ||||||
| match(C2, m_c_Add(m_c_Add(m_And(m_Specific(C3), m_SpecificInt(LowMask)), | ||||||
| m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))), | ||||||
| m_Value(XlYh))) || | ||||||
| match(C2, m_c_Add(m_c_Add(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2)), | ||||||
| m_Value(XlYh)), | ||||||
| m_And(m_Specific(C3), m_SpecificInt(LowMask))))) { | ||||||
| XhYl = C3; | ||||||
| } else { | ||||||
| // Match c3 = c2&0xffffffff + xl*yh | ||||||
| if (!match(C3, m_c_Add(m_And(m_Specific(C2), m_SpecificInt(LowMask)), | ||||||
| m_Value(XlYh)))) | ||||||
| std::swap(C2, C3); | ||||||
| if (!match(C3, m_c_Add(m_OneUse( | ||||||
| m_And(m_Specific(C2), m_SpecificInt(LowMask))), | ||||||
| m_Value(XlYh))) || | ||||||
| !C3->hasOneUse() || C2->hasNUsesOrMore(3)) | ||||||
| return false; | ||||||
|
|
||||||
| // Match c2 = xh*yl + (xl*yl >> 32) | ||||||
| if (!match(C2, m_c_Add(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2)), | ||||||
| m_Value(XhYl)))) | ||||||
| return false; | ||||||
| } | ||||||
|
|
||||||
| // Match XhYl and XlYh - they can appear either way around. | ||||||
| if (!CheckHiLo(XlYh, Y, X)) | ||||||
| std::swap(XlYh, XhYl); | ||||||
| if (!CheckHiLo(XlYh, Y, X)) | ||||||
| return false; | ||||||
| if (!CheckHiLo(XhYl, X, Y)) | ||||||
| return false; | ||||||
| if (!CheckLoLo(XlYl, X, Y)) | ||||||
| return false; | ||||||
|
|
||||||
| return CreateMulHigh(X, Y); | ||||||
| }; | ||||||
|
|
||||||
| auto FoldMulHighLadder4 = [&](Value *X, Value *Y, Instruction *A, | ||||||
| Instruction *B, Instruction *C) { | ||||||
| /// Ladder4: xh*yh + (xl*yh)>>32 + (xh+yl)>>32 + low>>32; | ||||||
| /// low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff | ||||||
|
|
||||||
| // Find A = Low >> 32 and B/C = XhYl>>32, XlYh>>32. | ||||||
| auto ShiftAdd = m_LShr(m_Add(m_Value(), m_Value()), m_SpecificInt(BW / 2)); | ||||||
| if (!match(A, ShiftAdd)) | ||||||
| std::swap(A, B); | ||||||
| if (!match(A, ShiftAdd)) | ||||||
| std::swap(A, C); | ||||||
| Value *Low; | ||||||
| if (!match(A, m_LShr(m_OneUse(m_Value(Low)), m_SpecificInt(BW / 2)))) | ||||||
| return false; | ||||||
|
|
||||||
| // Match B == XhYl>>32 and C == XlYh>>32 | ||||||
| Value *XhYl, *XlYh; | ||||||
| if (!match(B, m_LShr(m_Value(XhYl), m_SpecificInt(BW / 2))) || | ||||||
| !match(C, m_LShr(m_Value(XlYh), m_SpecificInt(BW / 2)))) | ||||||
| return false; | ||||||
| if (!CheckHiLo(XhYl, X, Y)) | ||||||
| std::swap(XhYl, XlYh); | ||||||
| if (!CheckHiLo(XhYl, X, Y) || XhYl->hasNUsesOrMore(3)) | ||||||
| return false; | ||||||
| if (!CheckHiLo(XlYh, Y, X) || XlYh->hasNUsesOrMore(3)) | ||||||
| return false; | ||||||
|
|
||||||
| // Match Low as XlYl>>32 + XhYl&0xffffffff + XlYh&0xffffffff | ||||||
| Value *XlYl; | ||||||
| if (!match( | ||||||
| Low, | ||||||
| m_c_Add( | ||||||
| m_OneUse(m_c_Add( | ||||||
| m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))), | ||||||
| m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))))), | ||||||
| m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))))) && | ||||||
| !match( | ||||||
| Low, | ||||||
| m_c_Add( | ||||||
| m_OneUse(m_c_Add( | ||||||
| m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))), | ||||||
| m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))))), | ||||||
| m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))))) && | ||||||
| !match( | ||||||
| Low, | ||||||
| m_c_Add( | ||||||
| m_OneUse(m_c_Add( | ||||||
| m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))), | ||||||
| m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))))), | ||||||
| m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask)))))) | ||||||
| return false; | ||||||
| if (!CheckLoLo(XlYl, X, Y)) | ||||||
| return false; | ||||||
|
|
||||||
| return CreateMulHigh(X, Y); | ||||||
| }; | ||||||
|
|
||||||
| auto FoldMulHighCarry4 = [&](Value *X, Value *Y, Instruction *Carry, | ||||||
| Instruction *B, Instruction *C) { | ||||||
| // xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32 | ||||||
| // crosssum = xh*yl+xl*yh | ||||||
| // carry = crosssum < xh*yl ? 0x1000000 : 0 | ||||||
| if (Carry->getOpcode() != Instruction::Select) | ||||||
| std::swap(Carry, B); | ||||||
| if (Carry->getOpcode() != Instruction::Select) | ||||||
| std::swap(Carry, C); | ||||||
|
|
||||||
| // Carry = CrossSum < XhYl ? 0x100000000 : 0 | ||||||
| Value *CrossSum, *XhYl; | ||||||
| if (!match(Carry, | ||||||
| m_OneUse(m_Select( | ||||||
| m_OneUse(m_SpecificICmp(ICmpInst::ICMP_ULT, | ||||||
| m_Value(CrossSum), m_Value(XhYl))), | ||||||
| m_SpecificInt(APInt::getOneBitSet(BW, BW / 2)), m_Zero())))) | ||||||
| return false; | ||||||
|
|
||||||
| if (!match(B, m_LShr(m_Specific(CrossSum), m_SpecificInt(BW / 2)))) | ||||||
| std::swap(B, C); | ||||||
| if (!match(B, m_LShr(m_Specific(CrossSum), m_SpecificInt(BW / 2)))) | ||||||
| return false; | ||||||
|
|
||||||
| Value *XlYl, *LowAccum; | ||||||
| if (!match(C, m_LShr(m_Value(LowAccum), m_SpecificInt(BW / 2))) || | ||||||
| !match(LowAccum, | ||||||
| m_c_Add(m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))), | ||||||
| m_OneUse(m_And(m_Specific(CrossSum), | ||||||
| m_SpecificInt(LowMask))))) || | ||||||
| LowAccum->hasNUsesOrMore(3)) | ||||||
| return false; | ||||||
| if (!CheckLoLo(XlYl, X, Y)) | ||||||
| return false; | ||||||
|
|
||||||
| if (!CheckHiLo(XhYl, X, Y)) | ||||||
| std::swap(X, Y); | ||||||
| if (!CheckHiLo(XhYl, X, Y)) | ||||||
| return false; | ||||||
| Value *XlYh; | ||||||
| if (!match(CrossSum, m_c_Add(m_Specific(XhYl), m_OneUse(m_Value(XlYh)))) || | ||||||
| !CheckHiLo(XlYh, Y, X) || CrossSum->hasNUsesOrMore(4) || | ||||||
| XhYl->hasNUsesOrMore(3)) | ||||||
| return false; | ||||||
|
|
||||||
| return CreateMulHigh(X, Y); | ||||||
| }; | ||||||
|
|
||||||
| // X and Y are the two inputs, A, B and C are other parts of the pattern | ||||||
| // (crosssum>>32, carry, etc). | ||||||
| Value *X, *Y; | ||||||
| Instruction *A, *B, *C; | ||||||
| auto HiHi = m_OneUse(m_Mul(m_LShr(m_Value(X), m_SpecificInt(BW / 2)), | ||||||
| m_LShr(m_Value(Y), m_SpecificInt(BW / 2)))); | ||||||
| if ((match(&I, m_c_Add(HiHi, m_OneUse(m_Add(m_Instruction(A), | ||||||
| m_Instruction(B))))) || | ||||||
| match(&I, m_c_Add(m_Instruction(A), | ||||||
| m_OneUse(m_c_Add(HiHi, m_Instruction(B)))))) && | ||||||
| A->hasOneUse() && B->hasOneUse()) | ||||||
| if (FoldMulHighCarry(X, Y, A, B) || FoldMulHighLadder(X, Y, A, B)) | ||||||
| return true; | ||||||
|
|
||||||
| if ((match(&I, m_c_Add(HiHi, m_OneUse(m_c_Add( | ||||||
| m_Instruction(A), | ||||||
| m_OneUse(m_Add(m_Instruction(B), | ||||||
| m_Instruction(C))))))) || | ||||||
| match(&I, m_c_Add(m_Instruction(A), | ||||||
| m_OneUse(m_c_Add( | ||||||
| HiHi, m_OneUse(m_Add(m_Instruction(B), | ||||||
| m_Instruction(C))))))) || | ||||||
| match(&I, m_c_Add(m_Instruction(A), | ||||||
| m_OneUse(m_c_Add( | ||||||
| m_Instruction(B), | ||||||
| m_OneUse(m_c_Add(HiHi, m_Instruction(C))))))) || | ||||||
| match(&I, | ||||||
| m_c_Add(m_OneUse(m_c_Add(HiHi, m_Instruction(A))), | ||||||
| m_OneUse(m_Add(m_Instruction(B), m_Instruction(C)))))) && | ||||||
| A->hasOneUse() && B->hasOneUse() && C->hasOneUse()) | ||||||
| return FoldMulHighCarry4(X, Y, A, B, C) || | ||||||
| FoldMulHighLadder4(X, Y, A, B, C); | ||||||
|
|
||||||
| return false; | ||||||
| } | ||||||
|
|
||||||
| /// This is the entry point for folds that could be implemented in regular | ||||||
| /// InstCombine, but they are separated because they are not expected to | ||||||
| /// occur frequently and/or have more than a constant-length pattern match. | ||||||
|
|
@@ -1495,6 +1808,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT, | |||||
| MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT); | ||||||
| MadeChange |= foldPatternedLoads(I, DL); | ||||||
| MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT); | ||||||
| MadeChange |= foldMulHigh(I); | ||||||
| // NOTE: This function introduces erasing of the instruction `I`, so it | ||||||
| // needs to be called at the end of this sequence, otherwise we may make | ||||||
| // bugs. | ||||||
|
|
||||||
Uh oh!
There was an error while loading. Please reload this page.