Skip to content

Commit 380f379

Browse files
committed
[AggressiveInstCombine] Match long high-half multiply
This patch adds recognition of a high-half multiply written out by parts, folding it into a single larger multiply. Considering a multiply made up of high and low parts, we can split the multiply into: x * y == (xh*T + xl) * (yh*T + yl), where xh == x>>32, xl == x & 0xffffffff and T == 2^32. This expands to xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl, which I find helpful to draw as [ xh*yh ] [ xh*yl ] [ xl*yh ] [ xl*yl ]. We are looking for the "high" half, which is xh*yh + (xh*yl>>32) + (xl*yh>>32) + carries. The carry makes this difficult and there are multiple ways of representing it. The ones we attempt to support here are: Carry: xh*yh + carry + (lowsum>>32), with carry = lowsum < xh*yl ? 0x100000000 : 0 and lowsum = xh*yl + xl*yh + (xl*yl>>32). Ladder: xh*yh + (c2>>32) + (c3>>32), with c2 = xh*yl + (xl*yl>>32) and c3 = (c2 & 0xffffffff) + xl*yh. Carry4: xh*yh + carry + (crosssum>>32) + (((xl*yl>>32) + (crosssum & 0xffffffff)) >> 32), with crosssum = xh*yl + xl*yh and carry = crosssum < xh*yl ? 0x100000000 : 0. Ladder4: xh*yh + (xl*yh>>32) + (xh*yl>>32) + (low>>32), with low = (xl*yl>>32) + (xl*yh & 0xffffffff) + (xh*yl & 0xffffffff). They all start by matching xh*yh + 2 or 3 other operands. The bottom of the tree is xh*yh, xh*yl, xl*yh and xl*yl. Based on #156879 by @c-rhodes
1 parent 2d35ea5 commit 380f379

File tree

5 files changed

+499
-632
lines changed

5 files changed

+499
-632
lines changed

llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,6 +1466,306 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
14661466
return false;
14671467
}
14681468

1469+
/// Match high part of long multiplication.
1470+
///
1471+
/// Considering a multiply made up of high and low parts, we can split the
1472+
/// multiply into:
1473+
/// x * y == (xh*T + xl) * (yh*T + yl)
1474+
/// where xh == x>>32 and xl == x & 0xffffffff. T = 2^32.
1475+
/// This expands to
1476+
/// xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl
1477+
/// which can be drawn as
1478+
/// [ xh*yh ]
1479+
/// [ xh*yl ]
1480+
/// [ xl*yh ]
1481+
/// [ xl*yl ]
1482+
/// We are looking for the "high" half, which is xh*yh + xh*yl>>32 + xl*yh>>32 +
1483+
/// some carrys. The carry makes this difficult and there are multiple ways of
1484+
/// representing it. The ones we attempt to support here are:
1485+
/// Carry: xh*yh + carry + lowsum
1486+
/// carry = lowsum < xh*yl ? 0x1000000 : 0
1487+
/// lowsum = xh*yl + xl*yh + (xl*yl>>32)
1488+
/// Ladder: xh*yh + c2>>32 + c3>>32
1489+
/// c2 = xh*yl + (xl*yl>>32); c3 = c2&0xffffffff + xl*yh
1490+
/// Carry4: xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32
1491+
/// crosssum = xh*yl + xl*yh
1492+
/// carry = crosssum < xh*yl ? 0x1000000 : 0
1493+
/// Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
1494+
/// low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
1495+
///
1496+
/// They all start by matching xh*yh + 2 or 3 other operands. The bottom of the
1497+
/// tree is xh*yh, xh*yl, xl*yh and xl*yl.
1498+
static bool foldMulHigh(Instruction &I) {
1499+
Type *Ty = I.getType();
1500+
if (!Ty->isIntOrIntVectorTy())
1501+
return false;
1502+
1503+
unsigned BW = Ty->getScalarSizeInBits();
1504+
APInt LowMask = APInt::getLowBitsSet(BW, BW / 2);
1505+
if (BW % 2 != 0)
1506+
return false;
1507+
1508+
auto CreateMulHigh = [&](Value *X, Value *Y) {
1509+
IRBuilder<> Builder(&I);
1510+
Type *NTy = Ty->getWithNewBitWidth(BW * 2);
1511+
Value *XExt = Builder.CreateZExt(X, NTy);
1512+
Value *YExt = Builder.CreateZExt(Y, NTy);
1513+
Value *Mul = Builder.CreateMul(XExt, YExt);
1514+
Value *High = Builder.CreateLShr(Mul, BW);
1515+
Value *Res = Builder.CreateTrunc(High, Ty);
1516+
I.replaceAllUsesWith(Res);
1517+
LLVM_DEBUG(dbgs() << "Created long multiply from parts of " << *X << " and "
1518+
<< *Y << "\n");
1519+
return true;
1520+
};
1521+
1522+
// Common check routines for X_lo*Y_lo and X_hi*Y_lo
1523+
auto CheckLoLo = [&](Value *XlYl, Value *X, Value *Y) {
1524+
return match(XlYl, m_c_Mul(m_And(m_Specific(X), m_SpecificInt(LowMask)),
1525+
m_And(m_Specific(Y), m_SpecificInt(LowMask))));
1526+
};
1527+
auto CheckHiLo = [&](Value *XhYl, Value *X, Value *Y) {
1528+
return match(XhYl, m_c_Mul(m_LShr(m_Specific(X), m_SpecificInt(BW / 2)),
1529+
m_And(m_Specific(Y), m_SpecificInt(LowMask))));
1530+
};
1531+
1532+
auto foldMulHighCarry = [&](Value *X, Value *Y, Instruction *Carry,
1533+
Instruction *B) {
1534+
// Looking for LowSum >> 32 and carry (select)
1535+
if (Carry->getOpcode() != Instruction::Select)
1536+
std::swap(Carry, B);
1537+
1538+
// Carry = LowSum < XhYl ? 0x100000000 : 0
1539+
CmpPredicate Pred;
1540+
Value *LowSum, *XhYl;
1541+
if (!match(Carry,
1542+
m_OneUse(m_Select(
1543+
m_OneUse(m_ICmp(Pred, m_Value(LowSum), m_Value(XhYl))),
1544+
m_SpecificInt(APInt(BW, 1) << BW / 2), m_SpecificInt(0)))) ||
1545+
Pred != ICmpInst::ICMP_ULT)
1546+
return false;
1547+
1548+
// XhYl can be Xh*Yl or Xl*Yh
1549+
if (!CheckHiLo(XhYl, X, Y)) {
1550+
if (CheckHiLo(XhYl, Y, X))
1551+
std::swap(X, Y);
1552+
else
1553+
return false;
1554+
}
1555+
if (XhYl->hasNUsesOrMore(3))
1556+
return false;
1557+
1558+
// B = LowSum >> 16
1559+
if (!match(B,
1560+
m_OneUse(m_LShr(m_Specific(LowSum), m_SpecificInt(BW / 2)))) ||
1561+
LowSum->hasNUsesOrMore(3))
1562+
return false;
1563+
1564+
// LowSum = XhYl + XlYh + XlYl>>32
1565+
Value *XlYh, *XlYl;
1566+
auto XlYlHi = m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2));
1567+
if (!match(LowSum,
1568+
m_c_Add(m_Specific(XhYl),
1569+
m_OneUse(m_c_Add(m_OneUse(m_Value(XlYh)), XlYlHi)))) &&
1570+
!match(LowSum, m_c_Add(m_OneUse(m_Value(XlYh)),
1571+
m_OneUse(m_c_Add(m_Specific(XhYl), XlYlHi)))) &&
1572+
!match(LowSum,
1573+
m_c_Add(XlYlHi, m_OneUse(m_c_Add(m_Specific(XhYl),
1574+
m_OneUse(m_Value(XlYh)))))))
1575+
return false;
1576+
1577+
// Check XlYl and XlYh
1578+
if (!CheckLoLo(XlYl, X, Y))
1579+
return false;
1580+
if (!CheckHiLo(XlYh, Y, X))
1581+
return false;
1582+
1583+
return CreateMulHigh(X, Y);
1584+
};
1585+
1586+
auto foldMulHighLadder = [&](Value *X, Value *Y, Instruction *A,
1587+
Instruction *B) {
1588+
// xh*yh + c2>>32 + c3>>32
1589+
// c2 = xh*yl + (xl*yl >> 32); c3 = c2&0xffffffff + xl*yh
1590+
Value *XlYh, *XhYl, *C2, *C3;
1591+
// Strip off the two expected shifts.
1592+
if (!match(A, m_LShr(m_Value(C2), m_SpecificInt(BW / 2))) ||
1593+
!match(B, m_LShr(m_Value(C3), m_SpecificInt(BW / 2))))
1594+
return false;
1595+
1596+
// Match c3 = c2&0xffffffff + xl*yh
1597+
if (!match(C3, m_c_Add(m_And(m_Specific(C2), m_SpecificInt(LowMask)),
1598+
m_Value(XhYl))))
1599+
std::swap(C2, C3);
1600+
if (!match(C3,
1601+
m_c_Add(m_OneUse(m_And(m_Specific(C2), m_SpecificInt(LowMask))),
1602+
m_Value(XhYl))) ||
1603+
!C3->hasOneUse() || C2->hasNUsesOrMore(3))
1604+
return false;
1605+
1606+
// Match c2 = xh*yl + (xl*yl >> 32)
1607+
Value *XlYl;
1608+
if (!match(C2, m_c_Add(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2)),
1609+
m_Value(XlYh))))
1610+
return false;
1611+
1612+
// Match XhYl and XlYh - they can appear either way around.
1613+
if (!CheckHiLo(XlYh, Y, X))
1614+
std::swap(XlYh, XhYl);
1615+
if (!CheckHiLo(XlYh, Y, X))
1616+
return false;
1617+
if (!CheckHiLo(XhYl, X, Y))
1618+
return false;
1619+
if (!CheckLoLo(XlYl, X, Y))
1620+
return false;
1621+
1622+
return CreateMulHigh(X, Y);
1623+
};
1624+
1625+
auto foldMulHighLadder4 = [&](Value *X, Value *Y, Instruction *A,
1626+
Instruction *B, Instruction *C) {
1627+
/// Ladder4: xh*yh + (xl*yh)>>32 + (xh+yl)>>32 + low>>32;
1628+
/// low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
1629+
1630+
// Find A = Low >> 32 and B/C = XhYl>>32, XlYh>>32.
1631+
auto ShiftAdd = m_LShr(m_Add(m_Value(), m_Value()), m_SpecificInt(BW / 2));
1632+
if (!match(A, ShiftAdd))
1633+
std::swap(A, B);
1634+
if (!match(A, ShiftAdd))
1635+
std::swap(A, C);
1636+
Value *Low;
1637+
if (!match(A, m_LShr(m_OneUse(m_Value(Low)), m_SpecificInt(BW / 2))))
1638+
return false;
1639+
1640+
// Match B == XhYl>>32 and C == XlYh>>32
1641+
Value *XhYl, *XlYh;
1642+
if (!match(B, m_LShr(m_Value(XhYl), m_SpecificInt(BW / 2))) ||
1643+
!match(C, m_LShr(m_Value(XlYh), m_SpecificInt(BW / 2))))
1644+
return false;
1645+
if (!CheckHiLo(XhYl, X, Y))
1646+
std::swap(XhYl, XlYh);
1647+
if (!CheckHiLo(XhYl, X, Y) || XhYl->hasNUsesOrMore(3))
1648+
return false;
1649+
if (!CheckHiLo(XlYh, Y, X) || XlYh->hasNUsesOrMore(3))
1650+
return false;
1651+
1652+
// Match Low as XlYl>>32 + XhYl&0xffffffff + XlYh&0xffffffff
1653+
Value *XlYl;
1654+
if (!match(
1655+
Low,
1656+
m_c_Add(
1657+
m_OneUse(m_c_Add(
1658+
m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))),
1659+
m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))))),
1660+
m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))))) &&
1661+
!match(
1662+
Low,
1663+
m_c_Add(
1664+
m_OneUse(m_c_Add(
1665+
m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))),
1666+
m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))))),
1667+
m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))))) &&
1668+
!match(
1669+
Low,
1670+
m_c_Add(
1671+
m_OneUse(m_c_Add(
1672+
m_OneUse(m_And(m_Specific(XlYh), m_SpecificInt(LowMask))),
1673+
m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))))),
1674+
m_OneUse(m_And(m_Specific(XhYl), m_SpecificInt(LowMask))))))
1675+
return false;
1676+
if (!CheckLoLo(XlYl, X, Y))
1677+
return false;
1678+
1679+
return CreateMulHigh(X, Y);
1680+
};
1681+
1682+
auto foldMulHighCarry4 = [&](Value *X, Value *Y, Instruction *Carry,
1683+
Instruction *B, Instruction *C) {
1684+
// xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32
1685+
// crosssum = xh*yl+xl*yh
1686+
// carry = crosssum < xh*yl ? 0x1000000 : 0
1687+
if (Carry->getOpcode() != Instruction::Select)
1688+
std::swap(Carry, B);
1689+
if (Carry->getOpcode() != Instruction::Select)
1690+
std::swap(Carry, C);
1691+
1692+
// Carry = CrossSum < XhYl ? 0x100000000 : 0
1693+
CmpPredicate Pred;
1694+
Value *CrossSum, *XhYl;
1695+
if (!match(Carry,
1696+
m_OneUse(m_Select(
1697+
m_OneUse(m_ICmp(Pred, m_Value(CrossSum), m_Value(XhYl))),
1698+
m_SpecificInt(APInt(BW, 1) << BW / 2), m_SpecificInt(0)))) ||
1699+
Pred != ICmpInst::ICMP_ULT)
1700+
return false;
1701+
1702+
if (!match(B, m_LShr(m_Specific(CrossSum), m_SpecificInt(BW / 2))))
1703+
std::swap(B, C);
1704+
if (!match(B, m_LShr(m_Specific(CrossSum), m_SpecificInt(BW / 2))))
1705+
return false;
1706+
1707+
Value *XlYl, *LowAccum;
1708+
if (!match(C, m_LShr(m_Value(LowAccum), m_SpecificInt(BW / 2))) ||
1709+
!match(LowAccum,
1710+
m_c_Add(m_OneUse(m_LShr(m_Value(XlYl), m_SpecificInt(BW / 2))),
1711+
m_OneUse(m_And(m_Specific(CrossSum),
1712+
m_SpecificInt(LowMask))))) ||
1713+
LowAccum->hasNUsesOrMore(3))
1714+
return false;
1715+
if (!CheckLoLo(XlYl, X, Y))
1716+
return false;
1717+
1718+
if (!CheckHiLo(XhYl, X, Y))
1719+
std::swap(X, Y);
1720+
if (!CheckHiLo(XhYl, X, Y))
1721+
return false;
1722+
if (!match(CrossSum,
1723+
m_c_Add(m_Specific(XhYl),
1724+
m_OneUse(m_c_Mul(
1725+
m_LShr(m_Specific(Y), m_SpecificInt(BW / 2)),
1726+
m_And(m_Specific(X), m_SpecificInt(LowMask)))))) ||
1727+
CrossSum->hasNUsesOrMore(4) || XhYl->hasNUsesOrMore(3))
1728+
return false;
1729+
1730+
return CreateMulHigh(X, Y);
1731+
};
1732+
1733+
// X and Y are the two inputs, A, B and C are other parts of the pattern
1734+
// (crosssum>>32, carry, etc).
1735+
Value *X, *Y;
1736+
Instruction *A, *B, *C;
1737+
auto HiHi = m_OneUse(m_Mul(m_LShr(m_Value(X), m_SpecificInt(BW / 2)),
1738+
m_LShr(m_Value(Y), m_SpecificInt(BW / 2))));
1739+
if ((match(&I, m_c_Add(HiHi, m_OneUse(m_Add(m_Instruction(A),
1740+
m_Instruction(B))))) ||
1741+
match(&I, m_c_Add(m_Instruction(A),
1742+
m_OneUse(m_c_Add(HiHi, m_Instruction(B)))))) &&
1743+
A->hasOneUse() && B->hasOneUse())
1744+
if (foldMulHighCarry(X, Y, A, B) || foldMulHighLadder(X, Y, A, B))
1745+
return true;
1746+
1747+
if ((match(&I, m_c_Add(HiHi, m_OneUse(m_c_Add(
1748+
m_Instruction(A),
1749+
m_OneUse(m_Add(m_Instruction(B),
1750+
m_Instruction(C))))))) ||
1751+
match(&I, m_c_Add(m_Instruction(A),
1752+
m_OneUse(m_c_Add(
1753+
HiHi, m_OneUse(m_Add(m_Instruction(B),
1754+
m_Instruction(C))))))) ||
1755+
match(&I, m_c_Add(m_Instruction(A),
1756+
m_OneUse(m_c_Add(
1757+
m_Instruction(B),
1758+
m_OneUse(m_c_Add(HiHi, m_Instruction(C))))))) ||
1759+
match(&I,
1760+
m_c_Add(m_OneUse(m_c_Add(HiHi, m_Instruction(A))),
1761+
m_OneUse(m_Add(m_Instruction(B), m_Instruction(C)))))) &&
1762+
A->hasOneUse() && B->hasOneUse() && C->hasOneUse())
1763+
return foldMulHighCarry4(X, Y, A, B, C) ||
1764+
foldMulHighLadder4(X, Y, A, B, C);
1765+
1766+
return false;
1767+
}
1768+
14691769
/// This is the entry point for folds that could be implemented in regular
14701770
/// InstCombine, but they are separated because they are not expected to
14711771
/// occur frequently and/or have more than a constant-length pattern match.
@@ -1495,6 +1795,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
14951795
MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
14961796
MadeChange |= foldPatternedLoads(I, DL);
14971797
MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT);
1798+
MadeChange |= foldMulHigh(I);
14981799
// NOTE: This function introduces erasing of the instruction `I`, so it
14991800
// needs to be called at the end of this sequence, otherwise we may make
15001801
// bugs.

0 commit comments

Comments
 (0)