Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
314 changes: 314 additions & 0 deletions llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1466,6 +1466,319 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
return false;
}

/// Match high part of long multiplication.
///
/// Considering a multiply made up of high and low parts, we can split the
/// multiply into:
/// x * y == (xh*T + xl) * (yh*T + yl)
/// where xh == x>>32 and xl == x & 0xffffffff. T = 2^32.
/// This expands to
/// xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl
/// which can be drawn as
/// [ xh*yh ]
/// [ xh*yl ]
/// [ xl*yh ]
/// [ xl*yl ]
/// We are looking for the "high" half, which is xh*yh + (xh*yl>>32) +
/// (xl*yh>>32) + some carries. The carries make this difficult and there are
/// multiple ways of representing them. The ones we attempt to support here are:
/// Carry: xh*yh + carry + lowsum
/// carry = lowsum < xh*yl ? 0x100000000 : 0
/// lowsum = xh*yl + xl*yh + (xl*yl>>32)
/// Ladder: xh*yh + c2>>32 + c3>>32
/// c2 = xh*yl + (xl*yl>>32); c3 = c2&0xffffffff + xl*yh
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is another variant that is a bit different from this pattern: https://github.com/Cyan4973/xxHash/blob/136cc1f8fe4d5ea62a7c16c8424d4fa5158f6d68/xxhash.h#L4568-L4582

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's an interesting one - the graph looks like a simpler version of what I have called ladder. It has less cross-edges, but I've incorporated it into the logic of FoldMulHighLadder. It would be nice if some of these canonicalized together but they are different enough that it seems difficult without matching the whole tree again.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like the same one as #60200

/// or c2 = (xl*yh&0xffffffff) + xh*yl + (xl*yl>>32); c3 = xl*yh
/// Carry4: xh*yh + carry + crosssum>>32 +
///         ((xl*yl>>32) + (crosssum&0xffffffff)) >> 32
/// crosssum = xh*yl + xl*yh
/// carry = crosssum < xh*yl ? 0x100000000 : 0
/// Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
/// low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
///
/// They all start by matching xh*yh + 2 or 3 other operands. The bottom of the
/// tree is xh*yh, xh*yl, xl*yh and xl*yl.
static bool foldMulHigh(Instruction &I) {
// Only integer results (scalar or vector of integers) are considered.
Type *Ty = I.getType();
if (!Ty->isIntOrIntVectorTy())
return false;

// BW is the per-element bit width; every "half" used by the matchers below is
// BW / 2 bits wide.
unsigned BW = Ty->getScalarSizeInBits();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: please spell it out

Suggested change
unsigned BW = Ty->getScalarSizeInBits();
unsigned BitWidth = Ty->getScalarSizeInBits();

// The patterns split each operand into two equally sized halves, so an odd
// element width cannot match.
if ((BW % 2) != 0)
  return false;
// Mask selecting the low BW / 2 bits of a value (the "xl"/"yl" parts).
APInt LowMask = APInt::getLowBitsSet(BW, BW / 2);

// Emits the replacement for I: the high BW bits of the 2*BW-bit product of X
// and Y, i.e. trunc((zext(X) * zext(Y)) >> BW). The new value takes over I's
// name and all of I's uses. Always returns true, so the matchers below can
// simply `return CreateMulHigh(X, Y);`.
auto CreateMulHigh = [&](Value *X, Value *Y) {
IRBuilder<> Builder(&I);
// Widen both operands to twice the element width so the product is exact.
Type *NTy = Ty->getWithNewBitWidth(BW * 2);
Value *XExt = Builder.CreateZExt(X, NTy);
Value *YExt = Builder.CreateZExt(Y, NTy);
// The trailing `true` is the HasNUW flag: the product of two zero-extended
// BW-bit values always fits in 2*BW bits.
Value *Mul = Builder.CreateMul(XExt, YExt, "", true);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Value *Mul = Builder.CreateMul(XExt, YExt, "", true);
Value *Mul = Builder.CreateMul(XExt, YExt, "", /*HasNUW=*/true);

same below

Value *High = Builder.CreateLShr(Mul, BW);
// NOTE(review): the trailing `true` is presumably CreateTrunc's
// no-unsigned-wrap flag (confirm against IRBuilder); after shifting right by
// BW only the low BW bits of High can be non-zero.
Value *Res = Builder.CreateTrunc(High, Ty, "", true);
Res->takeName(&I);
I.replaceAllUsesWith(Res);
LLVM_DEBUG(dbgs() << "Created long multiply from parts of " << *X << " and "
<< *Y << "\n");
return true;
};

// Returns true when XlYl is (X & LowMask) * (Y & LowMask), i.e. X_lo * Y_lo,
// with the two multiply operands in either order.
auto CheckLoLo = [&](Value *XlYl, Value *X, Value *Y) {
  auto LowX = m_And(m_Specific(X), m_SpecificInt(LowMask));
  auto LowY = m_And(m_Specific(Y), m_SpecificInt(LowMask));
  return match(XlYl, m_c_Mul(LowX, LowY));
};
// Returns true when XhYl is (X >> BW/2) * (Y & LowMask), i.e. X_hi * Y_lo,
// with the two multiply operands in either order.
auto CheckHiLo = [&](Value *XhYl, Value *X, Value *Y) {
  auto HighX = m_LShr(m_Specific(X), m_SpecificInt(BW / 2));
  auto LowY = m_And(m_Specific(Y), m_SpecificInt(LowMask));
  return match(XhYl, m_c_Mul(HighX, LowY));
};

// Matches the "Carry" form (constants shown for BW == 64):
//   I      = xh*yh + carry + (lowsum >> 32)
//   carry  = lowsum < xh*yl ? 0x100000000 : 0
//   lowsum = xh*yl + xl*yh + (xl*yl >> 32)
// X and Y are the inputs bound by the xh*yh match; Carry and B are the two
// remaining addends of I and may be passed in either order. Returns true if I
// was replaced through CreateMulHigh.
auto FoldMulHighCarry = [&](Value *X, Value *Y, Instruction *Carry,
Instruction *B) {
// Looking for LowSum >> 32 and carry (select)
// Put the select into Carry if it was passed as B.
if (Carry->getOpcode() != Instruction::Select)
std::swap(Carry, B);

// Carry = LowSum < XhYl ? 0x100000000 : 0
// Both the select and its compare must have a single use.
Value *LowSum, *XhYl;
if (!match(Carry,
m_OneUse(m_Select(
m_OneUse(m_SpecificICmp(ICmpInst::ICMP_ULT, m_Value(LowSum),
m_Value(XhYl))),
m_SpecificInt(APInt::getOneBitSet(BW, BW / 2)), m_Zero()))))
return false;

// XhYl can be Xh*Yl or Xl*Yh
// If it is the latter, exchange X and Y so that XhYl really is X_hi * Y_lo.
if (!CheckHiLo(XhYl, X, Y)) {
if (CheckHiLo(XhYl, Y, X))
std::swap(X, Y);
else
return false;
}
// Reject XhYl with three or more uses; the pattern itself uses it in the
// compare and in LowSum.
if (XhYl->hasNUsesOrMore(3))
return false;

// B = LowSum >> 32
// LowSum is used by this shift and by the compare above, so three or more
// uses are rejected.
if (!match(B,
m_OneUse(m_LShr(m_Specific(LowSum), m_SpecificInt(BW / 2)))) ||
LowSum->hasNUsesOrMore(3))
return false;

// LowSum = XhYl + XlYh + XlYl>>32
// The three terms may be associated in any of the three possible ways; the
// inner add and XlYh must each have a single use.
Value *XlYh, *XlYl;
auto XlYlHi = m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2));
if (!match(LowSum,
m_c_Add(m_Specific(XhYl),
m_OneUse(m_c_Add(m_OneUse(m_Value(XlYh)), XlYlHi)))) &&
!match(LowSum, m_c_Add(m_OneUse(m_Value(XlYh)),
m_OneUse(m_c_Add(m_Specific(XhYl), XlYlHi)))) &&
!match(LowSum,
m_c_Add(XlYlHi, m_OneUse(m_c_Add(m_Specific(XhYl),
m_OneUse(m_Value(XlYh)))))))
return false;

// Check XlYl and XlYh
// XlYh must be the other cross product, Y_hi * X_lo.
if (!CheckLoLo(XlYl, X, Y))
return false;
if (!CheckHiLo(XlYh, Y, X))
return false;

return CreateMulHigh(X, Y);
};

// Matches the "Ladder" forms (constants shown for BW == 64). X and Y are the
// inputs bound by the xh*yh match; A and B are the two remaining addends of I.
// Returns true if I was replaced through CreateMulHigh.
auto FoldMulHighLadder = [&](Value *X, Value *Y, Instruction *A,
Instruction *B) {
// xh*yh + c2>>32 + c3>>32
// c2 = xh*yl + (xl*yl>>32); c3 = (c2&0xffffffff) + xl*yh
// or c2 = (xh*yl&0xffffffff) + xl*yh + (xl*yl>>32); c3 = xh*yl
// In both variants the two cross products may be exchanged.
Value *XlYh, *XhYl, *XlYl, *C2, *C3;
// Strip off the two expected shifts.
if (!match(A, m_LShr(m_Value(C2), m_SpecificInt(BW / 2))) ||
!match(B, m_LShr(m_Value(C3), m_SpecificInt(BW / 2))))
return false;

// If C3 is an add that has another add as an operand (a three-term sum),
// exchange the names so that the three-term candidate is C2.
if (match(C3, m_c_Add(m_Add(m_Value(), m_Value()), m_Value())))
std::swap(C2, C3);
// Try to match c2 = (c3&0xffffffff) + cross product + (xl*yl>>32), trying all
// three ways of associating the three terms. On success C3 itself is one of
// the cross products.
if (match(C2, m_c_Add(m_c_Add(m_And(m_Specific(C3), m_SpecificInt(LowMask)),
m_Value(XlYh)),
m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2)))) ||
match(C2, m_c_Add(m_c_Add(m_And(m_Specific(C3), m_SpecificInt(LowMask)),
m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))),
m_Value(XlYh))) ||
match(C2, m_c_Add(m_c_Add(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2)),
m_Value(XlYh)),
m_And(m_Specific(C3), m_SpecificInt(LowMask))))) {
XhYl = C3;
} else {
// Match c3 = c2&0xffffffff + xl*yh
// A and B carry no order, so if the first assignment of C2/C3 does not fit,
// retry with the two exchanged.
if (!match(C3, m_c_Add(m_And(m_Specific(C2), m_SpecificInt(LowMask)),
m_Value(XlYh))))
std::swap(C2, C3);
// The masked C2 and C3 must each have one use; C2 may have at most two uses.
if (!match(C3, m_c_Add(m_OneUse(
m_And(m_Specific(C2), m_SpecificInt(LowMask))),
m_Value(XlYh))) ||
!C3->hasOneUse() || C2->hasNUsesOrMore(3))
return false;

// Match c2 = xh*yl + (xl*yl >> 32)
if (!match(C2, m_c_Add(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2)),
m_Value(XhYl))))
return false;
}

// Match XhYl and XlYh - they can appear either way around.
if (!CheckHiLo(XlYh, Y, X))
std::swap(XlYh, XhYl);
if (!CheckHiLo(XlYh, Y, X))
return false;
if (!CheckHiLo(XhYl, X, Y))
return false;
if (!CheckLoLo(XlYl, X, Y))
return false;

return CreateMulHigh(X, Y);
};

// Matches the "Ladder4" form (constants shown for BW == 64). X and Y are the
// inputs bound by the xh*yh match; A, B and C are the three remaining addends
// of I, in any order. Returns true if I was replaced through CreateMulHigh.
auto FoldMulHighLadder4 = [&](Value *X, Value *Y, Instruction *A,
Instruction *B, Instruction *C) {
// Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
// low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff

// Find A = Low >> 32 and B/C = XhYl>>32, XlYh>>32.
// Move whichever addend is a right-shifted add into A by exchanging it with
// B and then, if still needed, with C.
auto ShiftAdd = m_LShr(m_Add(m_Value(), m_Value()), m_SpecificInt(BW / 2));
if (!match(A, ShiftAdd))
std::swap(A, B);
if (!match(A, ShiftAdd))
std::swap(A, C);
// Low (the shifted value) must have a single use.
Value *Low;
if (!match(A, m_LShr(m_OneUse(m_Value(Low)), m_SpecificInt(BW / 2))))
return false;

// Match B == XhYl>>32 and C == XlYh>>32
// The two cross products may arrive in either order, hence the swap. Each is
// allowed at most two uses.
Value *XhYl, *XlYh;
if (!match(B, m_LShr(m_Value(XhYl), m_SpecificInt(BW / 2))) ||
!match(C, m_LShr(m_Value(XlYh), m_SpecificInt(BW / 2))))
return false;
if (!CheckHiLo(XhYl, X, Y))
std::swap(XhYl, XlYh);
if (!CheckHiLo(XhYl, X, Y) || XhYl->hasNUsesOrMore(3))
return false;
if (!CheckHiLo(XlYh, Y, X) || XlYh->hasNUsesOrMore(3))
return false;

// Match Low as XlYl>>32 + XhYl&0xffffffff + XlYh&0xffffffff
// All three ways of associating the three terms are tried; every
// intermediate value must have a single use.
Value *XlYl;
if (!match(
Low,
m_c_Add(
m_OneUse(m_c_Add(
m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))),
m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))))),
m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))))) &&
!match(
Low,
m_c_Add(
m_OneUse(m_c_Add(
m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))),
m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))))),
m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))))) &&
!match(
Low,
m_c_Add(
m_OneUse(m_c_Add(
m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))),
m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))))),
m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))))))
return false;
if (!CheckLoLo(XlYl, X, Y))
return false;

return CreateMulHigh(X, Y);
};

// Matches the "Carry4" form (constants shown for BW == 64). X and Y are the
// inputs bound by the xh*yh match; Carry, B and C are the three remaining
// addends of I, in any order. Returns true if I was replaced through
// CreateMulHigh.
auto FoldMulHighCarry4 = [&](Value *X, Value *Y, Instruction *Carry,
Instruction *B, Instruction *C) {
// xh*yh + carry + crosssum>>32 + ((xl*yl>>32) + (crosssum&0xffffffff)) >> 32
// crosssum = xh*yl+xl*yh
// carry = crosssum < xh*yl ? 0x100000000 : 0
// Move the select into Carry by exchanging it with B and then, if still
// needed, with C.
if (Carry->getOpcode() != Instruction::Select)
std::swap(Carry, B);
if (Carry->getOpcode() != Instruction::Select)
std::swap(Carry, C);

// Carry = CrossSum < XhYl ? 0x100000000 : 0
// Both the select and its compare must have a single use.
Value *CrossSum, *XhYl;
if (!match(Carry,
m_OneUse(m_Select(
m_OneUse(m_SpecificICmp(ICmpInst::ICMP_ULT,
m_Value(CrossSum), m_Value(XhYl))),
m_SpecificInt(APInt::getOneBitSet(BW, BW / 2)), m_Zero()))))
return false;

// B = CrossSum >> 32 (B and C may need to be exchanged first).
if (!match(B, m_LShr(m_Specific(CrossSum), m_SpecificInt(BW / 2))))
std::swap(B, C);
if (!match(B, m_LShr(m_Specific(CrossSum), m_SpecificInt(BW / 2))))
return false;

// C = LowAccum >> 32 with LowAccum = (XlYl >> 32) + (CrossSum & 0xffffffff).
// Both addends must have a single use; LowAccum may have at most two uses.
Value *XlYl, *LowAccum;
if (!match(C, m_LShr(m_Value(LowAccum), m_SpecificInt(BW / 2))) ||
!match(LowAccum,
m_c_Add(m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))),
m_OneUse(m_And(m_Specific(CrossSum),
m_SpecificInt(LowMask))))) ||
LowAccum->hasNUsesOrMore(3))
return false;
// CheckLoLo accepts its multiply operands in either order, so running it
// before the possible X/Y exchange below does not change the outcome.
if (!CheckLoLo(XlYl, X, Y))
return false;

// XhYl may be Xh*Yl or Xl*Yh; in the latter case exchange X and Y.
if (!CheckHiLo(XhYl, X, Y))
std::swap(X, Y);
if (!CheckHiLo(XhYl, X, Y))
return false;
// CrossSum = XhYl + XlYh. CrossSum is used by the compare, the shift and the
// mask (four or more uses are rejected); XhYl is used by the compare and by
// CrossSum (three or more uses are rejected).
Value *XlYh;
if (!match(CrossSum, m_c_Add(m_Specific(XhYl), m_OneUse(m_Value(XlYh)))) ||
!CheckHiLo(XlYh, Y, X) || CrossSum->hasNUsesOrMore(4) ||
XhYl->hasNUsesOrMore(3))
return false;

return CreateMulHigh(X, Y);
};

// X and Y are the two inputs, A, B and C are other parts of the pattern
// (crosssum>>32, carry, etc).
Value *X, *Y;
Instruction *A, *B, *C;
// HiHi is xh*yh: a single-use multiply of the two high halves. Matching it
// binds X and Y.
auto HiHi = m_OneUse(m_Mul(m_LShr(m_Value(X), m_SpecificInt(BW / 2)),
m_LShr(m_Value(Y), m_SpecificInt(BW / 2))));
// Three-operand sums: xh*yh + A + B in either association, where A and B are
// single-use instructions. If neither fold applies, control falls through to
// the four-operand patterns below.
if ((match(&I, m_c_Add(HiHi, m_OneUse(m_Add(m_Instruction(A),
m_Instruction(B))))) ||
match(&I, m_c_Add(m_Instruction(A),
m_OneUse(m_c_Add(HiHi, m_Instruction(B)))))) &&
A->hasOneUse() && B->hasOneUse())
if (FoldMulHighCarry(X, Y, A, B) || FoldMulHighLadder(X, Y, A, B))
return true;

// Four-operand sums: xh*yh + A + B + C in the four association shapes listed
// here, where A, B and C are single-use instructions.
if ((match(&I, m_c_Add(HiHi, m_OneUse(m_c_Add(
m_Instruction(A),
m_OneUse(m_Add(m_Instruction(B),
m_Instruction(C))))))) ||
match(&I, m_c_Add(m_Instruction(A),
m_OneUse(m_c_Add(
HiHi, m_OneUse(m_Add(m_Instruction(B),
m_Instruction(C))))))) ||
match(&I, m_c_Add(m_Instruction(A),
m_OneUse(m_c_Add(
m_Instruction(B),
m_OneUse(m_c_Add(HiHi, m_Instruction(C))))))) ||
match(&I,
m_c_Add(m_OneUse(m_c_Add(HiHi, m_Instruction(A))),
m_OneUse(m_Add(m_Instruction(B), m_Instruction(C)))))) &&
A->hasOneUse() && B->hasOneUse() && C->hasOneUse())
return FoldMulHighCarry4(X, Y, A, B, C) ||
FoldMulHighLadder4(X, Y, A, B, C);

return false;
}

/// This is the entry point for folds that could be implemented in regular
/// InstCombine, but they are separated because they are not expected to
/// occur frequently and/or have more than a constant-length pattern match.
Expand Down Expand Up @@ -1495,6 +1808,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
MadeChange |= foldPatternedLoads(I, DL);
MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT);
MadeChange |= foldMulHigh(I);
// NOTE: This function introduces erasing of the instruction `I`, so it
// needs to be called at the end of this sequence, otherwise we may make
// bugs.
Expand Down
Loading