Skip to content

Commit 6abbbca

Browse files
authored
[AggressiveInstCombine] Match long high-half multiply (#168396)
This patch adds recognition of high-half multiply by parts into a single larger multiply. Considering a multiply made up of high and low parts, we can split the multiply into: x * y == (xh*T + xl) * (yh*T + yl) where `xh == x>>32` and `xl == x & 0xffffffff`. `T = 2^32`. This expands to xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl which I find it helpful to be drawn as [ xh*yh ] [ xh*yl ] [ xl*yh ] [ xl*yl ] We are looking for the "high" half, which is xh*yh + xh*yl>>32 + xl*yh>>32 + carrys. The carry makes this difficult and there are multiple ways of representing it. The ones we attempt to support here are: Carry: xh*yh + carry + lowsum carry = lowsum < xh*yl ? 0x1000000 : 0 lowsum = xh*yl + xl*yh + (xl*yl>>32) Ladder: xh*yh + c2>>32 + c3>>32 c2 = xh*yl + (xl*yl >> 32); c3 = c2&0xffffffff + xl*yh Carry4: xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32 crosssum = xh*yl + xl*yh carry = crosssum < xh*yl ? 0x1000000 : 0 Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32; low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xfffffff They all start by matching `xh*yh` + 2 or 3 other operands. The bottom of the tree is `xh*yh`, `xh*yl`, `xl*yh` and `xl*yl`. Based on #156879 by @c-rhodes
1 parent bb9449d commit 6abbbca

File tree

5 files changed

+5486
-0
lines changed

5 files changed

+5486
-0
lines changed

llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp

Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,6 +1466,329 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
14661466
return false;
14671467
}
14681468

1469+
/// Match high part of long multiplication.
1470+
///
1471+
/// Considering a multiply made up of high and low parts, we can split the
1472+
/// multiply into:
1473+
/// x * y == (xh*T + xl) * (yh*T + yl)
1474+
/// where xh == x>>32 and xl == x & 0xffffffff. T = 2^32.
1475+
/// This expands to
1476+
/// xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl
1477+
/// which can be drawn as
1478+
/// [ xh*yh ]
1479+
/// [ xh*yl ]
1480+
/// [ xl*yh ]
1481+
/// [ xl*yl ]
1482+
/// We are looking for the "high" half, which is xh*yh + xh*yl>>32 + xl*yh>>32 +
1483+
/// some carrys. The carry makes this difficult and there are multiple ways of
1484+
/// representing it. The ones we attempt to support here are:
1485+
/// Carry: xh*yh + carry + lowsum
1486+
/// carry = lowsum < xh*yl ? 0x1000000 : 0
1487+
/// lowsum = xh*yl + xl*yh + (xl*yl>>32)
1488+
/// Ladder: xh*yh + c2>>32 + c3>>32
1489+
/// c2 = xh*yl + (xl*yl>>32); c3 = c2&0xffffffff + xl*yh
1490+
/// or c2 = (xl*yh&0xffffffff) + xh*yl + (xl*yl>>32); c3 = xl*yh
1491+
/// Carry4: xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32
1492+
/// crosssum = xh*yl + xl*yh
1493+
/// carry = crosssum < xh*yl ? 0x1000000 : 0
1494+
/// Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
1495+
/// low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
1496+
///
1497+
/// They all start by matching xh*yh + 2 or 3 other operands. The bottom of the
1498+
/// tree is xh*yh, xh*yl, xl*yh and xl*yl.
1499+
static bool foldMulHigh(Instruction &I) {
1500+
Type *Ty = I.getType();
1501+
if (!Ty->isIntOrIntVectorTy())
1502+
return false;
1503+
1504+
unsigned BitWidth = Ty->getScalarSizeInBits();
1505+
APInt LowMask = APInt::getLowBitsSet(BitWidth, BitWidth / 2);
1506+
if (BitWidth % 2 != 0)
1507+
return false;
1508+
1509+
auto CreateMulHigh = [&](Value *X, Value *Y) {
1510+
IRBuilder<> Builder(&I);
1511+
Type *NTy = Ty->getWithNewBitWidth(BitWidth * 2);
1512+
Value *XExt = Builder.CreateZExt(X, NTy);
1513+
Value *YExt = Builder.CreateZExt(Y, NTy);
1514+
Value *Mul = Builder.CreateMul(XExt, YExt, "", /*HasNUW=*/true);
1515+
Value *High = Builder.CreateLShr(Mul, BitWidth);
1516+
Value *Res = Builder.CreateTrunc(High, Ty, "", /*HasNUW=*/true);
1517+
Res->takeName(&I);
1518+
I.replaceAllUsesWith(Res);
1519+
LLVM_DEBUG(dbgs() << "Created long multiply from parts of " << *X << " and "
1520+
<< *Y << "\n");
1521+
return true;
1522+
};
1523+
1524+
// Common check routines for X_lo*Y_lo and X_hi*Y_lo
1525+
auto CheckLoLo = [&](Value *XlYl, Value *X, Value *Y) {
1526+
return match(XlYl, m_c_Mul(m_And(m_Specific(X), m_SpecificInt(LowMask)),
1527+
m_And(m_Specific(Y), m_SpecificInt(LowMask))));
1528+
};
1529+
auto CheckHiLo = [&](Value *XhYl, Value *X, Value *Y) {
1530+
return match(XhYl,
1531+
m_c_Mul(m_LShr(m_Specific(X), m_SpecificInt(BitWidth / 2)),
1532+
m_And(m_Specific(Y), m_SpecificInt(LowMask))));
1533+
};
1534+
1535+
auto FoldMulHighCarry = [&](Value *X, Value *Y, Instruction *Carry,
1536+
Instruction *B) {
1537+
// Looking for LowSum >> 32 and carry (select)
1538+
if (Carry->getOpcode() != Instruction::Select)
1539+
std::swap(Carry, B);
1540+
1541+
// Carry = LowSum < XhYl ? 0x100000000 : 0
1542+
Value *LowSum, *XhYl;
1543+
if (!match(Carry,
1544+
m_OneUse(m_Select(
1545+
m_OneUse(m_SpecificICmp(ICmpInst::ICMP_ULT, m_Value(LowSum),
1546+
m_Value(XhYl))),
1547+
m_SpecificInt(APInt::getOneBitSet(BitWidth, BitWidth / 2)),
1548+
m_Zero()))))
1549+
return false;
1550+
1551+
// XhYl can be Xh*Yl or Xl*Yh
1552+
if (!CheckHiLo(XhYl, X, Y)) {
1553+
if (CheckHiLo(XhYl, Y, X))
1554+
std::swap(X, Y);
1555+
else
1556+
return false;
1557+
}
1558+
if (XhYl->hasNUsesOrMore(3))
1559+
return false;
1560+
1561+
// B = LowSum >> 32
1562+
if (!match(B, m_OneUse(m_LShr(m_Specific(LowSum),
1563+
m_SpecificInt(BitWidth / 2)))) ||
1564+
LowSum->hasNUsesOrMore(3))
1565+
return false;
1566+
1567+
// LowSum = XhYl + XlYh + XlYl>>32
1568+
Value *XlYh, *XlYl;
1569+
auto XlYlHi = m_LShr(m_Value(XlYl), m_SpecificInt(BitWidth / 2));
1570+
if (!match(LowSum,
1571+
m_c_Add(m_Specific(XhYl),
1572+
m_OneUse(m_c_Add(m_OneUse(m_Value(XlYh)), XlYlHi)))) &&
1573+
!match(LowSum, m_c_Add(m_OneUse(m_Value(XlYh)),
1574+
m_OneUse(m_c_Add(m_Specific(XhYl), XlYlHi)))) &&
1575+
!match(LowSum,
1576+
m_c_Add(XlYlHi, m_OneUse(m_c_Add(m_Specific(XhYl),
1577+
m_OneUse(m_Value(XlYh)))))))
1578+
return false;
1579+
1580+
// Check XlYl and XlYh
1581+
if (!CheckLoLo(XlYl, X, Y))
1582+
return false;
1583+
if (!CheckHiLo(XlYh, Y, X))
1584+
return false;
1585+
1586+
return CreateMulHigh(X, Y);
1587+
};
1588+
1589+
auto FoldMulHighLadder = [&](Value *X, Value *Y, Instruction *A,
1590+
Instruction *B) {
1591+
// xh*yh + c2>>32 + c3>>32
1592+
// c2 = xh*yl + (xl*yl>>32); c3 = c2&0xffffffff + xl*yh
1593+
// or c2 = (xl*yh&0xffffffff) + xh*yl + (xl*yl>>32); c3 = xh*yl
1594+
Value *XlYh, *XhYl, *XlYl, *C2, *C3;
1595+
// Strip off the two expected shifts.
1596+
if (!match(A, m_LShr(m_Value(C2), m_SpecificInt(BitWidth / 2))) ||
1597+
!match(B, m_LShr(m_Value(C3), m_SpecificInt(BitWidth / 2))))
1598+
return false;
1599+
1600+
if (match(C3, m_c_Add(m_Add(m_Value(), m_Value()), m_Value())))
1601+
std::swap(C2, C3);
1602+
// Try to match c2 = (xl*yh&0xffffffff) + xh*yl + (xl*yl>>32)
1603+
if (match(C2,
1604+
m_c_Add(m_c_Add(m_And(m_Specific(C3), m_SpecificInt(LowMask)),
1605+
m_Value(XlYh)),
1606+
m_LShr(m_Value(XlYl), m_SpecificInt(BitWidth / 2)))) ||
1607+
match(C2, m_c_Add(m_c_Add(m_And(m_Specific(C3), m_SpecificInt(LowMask)),
1608+
m_LShr(m_Value(XlYl),
1609+
m_SpecificInt(BitWidth / 2))),
1610+
m_Value(XlYh))) ||
1611+
match(C2, m_c_Add(m_c_Add(m_LShr(m_Value(XlYl),
1612+
m_SpecificInt(BitWidth / 2)),
1613+
m_Value(XlYh)),
1614+
m_And(m_Specific(C3), m_SpecificInt(LowMask))))) {
1615+
XhYl = C3;
1616+
} else {
1617+
// Match c3 = c2&0xffffffff + xl*yh
1618+
if (!match(C3, m_c_Add(m_And(m_Specific(C2), m_SpecificInt(LowMask)),
1619+
m_Value(XlYh))))
1620+
std::swap(C2, C3);
1621+
if (!match(C3, m_c_Add(m_OneUse(
1622+
m_And(m_Specific(C2), m_SpecificInt(LowMask))),
1623+
m_Value(XlYh))) ||
1624+
!C3->hasOneUse() || C2->hasNUsesOrMore(3))
1625+
return false;
1626+
1627+
// Match c2 = xh*yl + (xl*yl >> 32)
1628+
if (!match(C2, m_c_Add(m_LShr(m_Value(XlYl), m_SpecificInt(BitWidth / 2)),
1629+
m_Value(XhYl))))
1630+
return false;
1631+
}
1632+
1633+
// Match XhYl and XlYh - they can appear either way around.
1634+
if (!CheckHiLo(XlYh, Y, X))
1635+
std::swap(XlYh, XhYl);
1636+
if (!CheckHiLo(XlYh, Y, X))
1637+
return false;
1638+
if (!CheckHiLo(XhYl, X, Y))
1639+
return false;
1640+
if (!CheckLoLo(XlYl, X, Y))
1641+
return false;
1642+
1643+
return CreateMulHigh(X, Y);
1644+
};
1645+
1646+
auto FoldMulHighLadder4 = [&](Value *X, Value *Y, Instruction *A,
1647+
Instruction *B, Instruction *C) {
1648+
/// Ladder4: xh*yh + (xl*yh)>>32 + (xh+yl)>>32 + low>>32;
1649+
/// low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
1650+
1651+
// Find A = Low >> 32 and B/C = XhYl>>32, XlYh>>32.
1652+
auto ShiftAdd =
1653+
m_LShr(m_Add(m_Value(), m_Value()), m_SpecificInt(BitWidth / 2));
1654+
if (!match(A, ShiftAdd))
1655+
std::swap(A, B);
1656+
if (!match(A, ShiftAdd))
1657+
std::swap(A, C);
1658+
Value *Low;
1659+
if (!match(A, m_LShr(m_OneUse(m_Value(Low)), m_SpecificInt(BitWidth / 2))))
1660+
return false;
1661+
1662+
// Match B == XhYl>>32 and C == XlYh>>32
1663+
Value *XhYl, *XlYh;
1664+
if (!match(B, m_LShr(m_Value(XhYl), m_SpecificInt(BitWidth / 2))) ||
1665+
!match(C, m_LShr(m_Value(XlYh), m_SpecificInt(BitWidth / 2))))
1666+
return false;
1667+
if (!CheckHiLo(XhYl, X, Y))
1668+
std::swap(XhYl, XlYh);
1669+
if (!CheckHiLo(XhYl, X, Y) || XhYl->hasNUsesOrMore(3))
1670+
return false;
1671+
if (!CheckHiLo(XlYh, Y, X) || XlYh->hasNUsesOrMore(3))
1672+
return false;
1673+
1674+
// Match Low as XlYl>>32 + XhYl&0xffffffff + XlYh&0xffffffff
1675+
Value *XlYl;
1676+
if (!match(
1677+
Low,
1678+
m_c_Add(
1679+
m_OneUse(m_c_Add(
1680+
m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))),
1681+
m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))))),
1682+
m_OneUse(
1683+
m_LShr(m_Value(XlYl), m_SpecificInt(BitWidth / 2))))) &&
1684+
!match(
1685+
Low,
1686+
m_c_Add(
1687+
m_OneUse(m_c_Add(
1688+
m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))),
1689+
m_OneUse(
1690+
m_LShr(m_Value(XlYl), m_SpecificInt(BitWidth / 2))))),
1691+
m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))))) &&
1692+
!match(
1693+
Low,
1694+
m_c_Add(
1695+
m_OneUse(m_c_Add(
1696+
m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))),
1697+
m_OneUse(
1698+
m_LShr(m_Value(XlYl), m_SpecificInt(BitWidth / 2))))),
1699+
m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))))))
1700+
return false;
1701+
if (!CheckLoLo(XlYl, X, Y))
1702+
return false;
1703+
1704+
return CreateMulHigh(X, Y);
1705+
};
1706+
1707+
auto FoldMulHighCarry4 = [&](Value *X, Value *Y, Instruction *Carry,
1708+
Instruction *B, Instruction *C) {
1709+
// xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32
1710+
// crosssum = xh*yl+xl*yh
1711+
// carry = crosssum < xh*yl ? 0x1000000 : 0
1712+
if (Carry->getOpcode() != Instruction::Select)
1713+
std::swap(Carry, B);
1714+
if (Carry->getOpcode() != Instruction::Select)
1715+
std::swap(Carry, C);
1716+
1717+
// Carry = CrossSum < XhYl ? 0x100000000 : 0
1718+
Value *CrossSum, *XhYl;
1719+
if (!match(Carry,
1720+
m_OneUse(m_Select(
1721+
m_OneUse(m_SpecificICmp(ICmpInst::ICMP_ULT,
1722+
m_Value(CrossSum), m_Value(XhYl))),
1723+
m_SpecificInt(APInt::getOneBitSet(BitWidth, BitWidth / 2)),
1724+
m_Zero()))))
1725+
return false;
1726+
1727+
if (!match(B, m_LShr(m_Specific(CrossSum), m_SpecificInt(BitWidth / 2))))
1728+
std::swap(B, C);
1729+
if (!match(B, m_LShr(m_Specific(CrossSum), m_SpecificInt(BitWidth / 2))))
1730+
return false;
1731+
1732+
Value *XlYl, *LowAccum;
1733+
if (!match(C, m_LShr(m_Value(LowAccum), m_SpecificInt(BitWidth / 2))) ||
1734+
!match(LowAccum, m_c_Add(m_OneUse(m_LShr(m_Value(XlYl),
1735+
m_SpecificInt(BitWidth / 2))),
1736+
m_OneUse(m_And(m_Specific(CrossSum),
1737+
m_SpecificInt(LowMask))))) ||
1738+
LowAccum->hasNUsesOrMore(3))
1739+
return false;
1740+
if (!CheckLoLo(XlYl, X, Y))
1741+
return false;
1742+
1743+
if (!CheckHiLo(XhYl, X, Y))
1744+
std::swap(X, Y);
1745+
if (!CheckHiLo(XhYl, X, Y))
1746+
return false;
1747+
Value *XlYh;
1748+
if (!match(CrossSum, m_c_Add(m_Specific(XhYl), m_OneUse(m_Value(XlYh)))) ||
1749+
!CheckHiLo(XlYh, Y, X) || CrossSum->hasNUsesOrMore(4) ||
1750+
XhYl->hasNUsesOrMore(3))
1751+
return false;
1752+
1753+
return CreateMulHigh(X, Y);
1754+
};
1755+
1756+
// X and Y are the two inputs, A, B and C are other parts of the pattern
1757+
// (crosssum>>32, carry, etc).
1758+
Value *X, *Y;
1759+
Instruction *A, *B, *C;
1760+
auto HiHi = m_OneUse(m_Mul(m_LShr(m_Value(X), m_SpecificInt(BitWidth / 2)),
1761+
m_LShr(m_Value(Y), m_SpecificInt(BitWidth / 2))));
1762+
if ((match(&I, m_c_Add(HiHi, m_OneUse(m_Add(m_Instruction(A),
1763+
m_Instruction(B))))) ||
1764+
match(&I, m_c_Add(m_Instruction(A),
1765+
m_OneUse(m_c_Add(HiHi, m_Instruction(B)))))) &&
1766+
A->hasOneUse() && B->hasOneUse())
1767+
if (FoldMulHighCarry(X, Y, A, B) || FoldMulHighLadder(X, Y, A, B))
1768+
return true;
1769+
1770+
if ((match(&I, m_c_Add(HiHi, m_OneUse(m_c_Add(
1771+
m_Instruction(A),
1772+
m_OneUse(m_Add(m_Instruction(B),
1773+
m_Instruction(C))))))) ||
1774+
match(&I, m_c_Add(m_Instruction(A),
1775+
m_OneUse(m_c_Add(
1776+
HiHi, m_OneUse(m_Add(m_Instruction(B),
1777+
m_Instruction(C))))))) ||
1778+
match(&I, m_c_Add(m_Instruction(A),
1779+
m_OneUse(m_c_Add(
1780+
m_Instruction(B),
1781+
m_OneUse(m_c_Add(HiHi, m_Instruction(C))))))) ||
1782+
match(&I,
1783+
m_c_Add(m_OneUse(m_c_Add(HiHi, m_Instruction(A))),
1784+
m_OneUse(m_Add(m_Instruction(B), m_Instruction(C)))))) &&
1785+
A->hasOneUse() && B->hasOneUse() && C->hasOneUse())
1786+
return FoldMulHighCarry4(X, Y, A, B, C) ||
1787+
FoldMulHighLadder4(X, Y, A, B, C);
1788+
1789+
return false;
1790+
}
1791+
14691792
/// This is the entry point for folds that could be implemented in regular
14701793
/// InstCombine, but they are separated because they are not expected to
14711794
/// occur frequently and/or have more than a constant-length pattern match.
@@ -1495,6 +1818,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
14951818
MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
14961819
MadeChange |= foldPatternedLoads(I, DL);
14971820
MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT);
1821+
MadeChange |= foldMulHigh(I);
14981822
// NOTE: This function introduces erasing of the instruction `I`, so it
14991823
// needs to be called at the end of this sequence, otherwise we may make
15001824
// bugs.

0 commit comments

Comments
 (0)