Skip to content

Commit 81207f2

Browse files
committed
[AggressiveInstCombine] Fold i64 x i64 -> i128 multiply-by-parts
This patch adds patterns to recognize a full i64 x i64 -> i128 multiplication by 4 x i32 parts, folding it to a full 128-bit multiply. The low and high parts are implemented as independent patterns. There is also an additional pattern for the high part; both patterns have been seen in real code, and there is one more I'm aware of, but I thought I'd share a patch first to see what people think before handling any further cases. On AArch64 the mul and umulh instructions can be used to efficiently compute the low/high parts. I also believe X86 can do the i128 mul in one instruction (returning both halves), so it seems like this is relatively common and could be a useful optimization for several targets.
1 parent 889db04 commit 81207f2

File tree

2 files changed

+2827
-0
lines changed

2 files changed

+2827
-0
lines changed

llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1457,6 +1457,259 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
14571457
return false;
14581458
}
14591459

1460+
/// Match low part of 128-bit multiplication.
1461+
static bool foldMul128Low(Instruction &I, const DataLayout &DL,
1462+
DominatorTree &DT) {
1463+
auto *Ty = I.getType();
1464+
if (!Ty->isIntegerTy(64))
1465+
return false;
1466+
1467+
// (low_accum << 32) | lo(lo(y) * lo(x))
1468+
Value *LowAccum = nullptr, *YLowXLow = nullptr;
1469+
if (!match(&I, m_c_DisjointOr(
1470+
m_OneUse(m_Shl(m_Value(LowAccum), m_SpecificInt(32))),
1471+
m_OneUse(
1472+
m_And(m_Value(YLowXLow), m_SpecificInt(0xffffffff))))))
1473+
return false;
1474+
1475+
// lo(cross_sum) + hi(lo(y) * lo(x))
1476+
Value *CrossSum = nullptr;
1477+
if (!match(
1478+
LowAccum,
1479+
m_c_Add(m_OneUse(m_And(m_Value(CrossSum), m_SpecificInt(0xffffffff))),
1480+
m_OneUse(m_LShr(m_Specific(YLowXLow), m_SpecificInt(32))))) ||
1481+
LowAccum->hasNUsesOrMore(3))
1482+
return false;
1483+
1484+
// (hi(y) * lo(x)) + (lo(y) * hi(x))
1485+
Value *YHigh = nullptr, *XLow = nullptr, *YLowXHigh = nullptr;
1486+
if (!match(CrossSum, m_c_Add(m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XLow))),
1487+
m_Value(YLowXHigh))) ||
1488+
CrossSum->hasNUsesOrMore(4))
1489+
return false;
1490+
1491+
// lo(y) * lo(x)
1492+
Value *YLow = nullptr;
1493+
if (!match(YLowXLow, m_c_Mul(m_Value(YLow), m_Specific(XLow))) ||
1494+
YLowXLow->hasNUsesOrMore(3))
1495+
return false;
1496+
1497+
// lo(y) * hi(x)
1498+
Value *XHigh = nullptr;
1499+
if (!match(YLowXHigh, m_c_Mul(m_Specific(YLow), m_Value(XHigh))) ||
1500+
!YLowXHigh->hasNUses(2))
1501+
return false;
1502+
1503+
Value *X = nullptr;
1504+
// lo(x) = x & 0xffffffff
1505+
if (!match(XLow, m_c_And(m_Value(X), m_SpecificInt(0xffffffff))) ||
1506+
!XLow->hasNUses(2))
1507+
return false;
1508+
// hi(x) = x >> 32
1509+
if (!match(XHigh, m_LShr(m_Specific(X), m_SpecificInt(32))) ||
1510+
!XHigh->hasNUses(2))
1511+
return false;
1512+
1513+
// Same for Y.
1514+
Value *Y = nullptr;
1515+
if (!match(YLow, m_c_And(m_Value(Y), m_SpecificInt(0xffffffff))) ||
1516+
!YLow->hasNUses(2))
1517+
return false;
1518+
if (!match(YHigh, m_LShr(m_Specific(Y), m_SpecificInt(32))) ||
1519+
!YHigh->hasNUses(2))
1520+
return false;
1521+
1522+
IRBuilder<> Builder(&I);
1523+
Value *XExt = Builder.CreateZExt(X, Builder.getInt128Ty());
1524+
Value *YExt = Builder.CreateZExt(Y, Builder.getInt128Ty());
1525+
Value *Mul128 = Builder.CreateMul(XExt, YExt);
1526+
Value *Res = Builder.CreateTrunc(Mul128, Builder.getInt64Ty());
1527+
I.replaceAllUsesWith(Res);
1528+
1529+
return true;
1530+
}
1531+
1532+
/// Match high part of 128-bit multiplication.
1533+
static bool foldMul128High(Instruction &I, const DataLayout &DL,
1534+
DominatorTree &DT) {
1535+
auto *Ty = I.getType();
1536+
if (!Ty->isIntegerTy(64))
1537+
return false;
1538+
1539+
// intermediate_plus_carry + hi(low_accum)
1540+
Value *IntermediatePlusCarry = nullptr, *LowAccum = nullptr;
1541+
if (!match(&I,
1542+
m_c_Add(m_OneUse(m_Value(IntermediatePlusCarry)),
1543+
m_OneUse(m_LShr(m_Value(LowAccum), m_SpecificInt(32))))))
1544+
return false;
1545+
1546+
// match:
1547+
// (((hi(y) * hi(x)) + carry) + hi(cross_sum))
1548+
// or:
1549+
// ((hi(cross_sum) + (hi(y) * hi(x))) + carry)
1550+
CmpPredicate Pred;
1551+
Value *CrossSum = nullptr, *XHigh = nullptr, *YHigh = nullptr,
1552+
*Carry = nullptr;
1553+
if (!match(IntermediatePlusCarry,
1554+
m_c_Add(m_c_Add(m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XHigh))),
1555+
m_Value(Carry)),
1556+
m_OneUse(m_LShr(m_Value(CrossSum), m_SpecificInt(32))))) &&
1557+
!match(IntermediatePlusCarry,
1558+
m_c_Add(m_OneUse(m_c_Add(
1559+
m_OneUse(m_LShr(m_Value(CrossSum), m_SpecificInt(32))),
1560+
m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XHigh))))),
1561+
m_Value(Carry))))
1562+
return false;
1563+
1564+
// (select (icmp ult cross_sum, (lo(y) * hi(x))), (1 << 32), 0)
1565+
Value *YLowXHigh = nullptr;
1566+
if (!match(Carry,
1567+
m_OneUse(m_Select(m_OneUse(m_ICmp(Pred, m_Specific(CrossSum),
1568+
m_Value(YLowXHigh))),
1569+
m_SpecificInt(4294967296), m_SpecificInt(0)))) ||
1570+
Pred != ICmpInst::ICMP_ULT)
1571+
return false;
1572+
1573+
// (hi(y) * lo(x)) + (lo(y) * hi(x))
1574+
Value *XLow = nullptr;
1575+
if (!match(CrossSum,
1576+
m_c_Add(m_OneUse(m_c_Mul(m_Specific(YHigh), m_Value(XLow))),
1577+
m_Specific(YLowXHigh))) ||
1578+
CrossSum->hasNUsesOrMore(4))
1579+
return false;
1580+
1581+
// lo(y) * hi(x)
1582+
Value *YLow = nullptr;
1583+
if (!match(YLowXHigh, m_c_Mul(m_Value(YLow), m_Specific(XHigh))) ||
1584+
!YLowXHigh->hasNUses(2))
1585+
return false;
1586+
1587+
// lo(cross_sum) + hi(lo(y) * lo(x))
1588+
Value *YLowXLow = nullptr;
1589+
if (!match(LowAccum,
1590+
m_c_Add(m_OneUse(m_c_And(m_Specific(CrossSum),
1591+
m_SpecificInt(0xffffffff))),
1592+
m_OneUse(m_LShr(m_Value(YLowXLow), m_SpecificInt(32))))) ||
1593+
LowAccum->hasNUsesOrMore(3))
1594+
return false;
1595+
1596+
// lo(y) * lo(x)
1597+
//
1598+
// When only doing the high part there's a single use and 2 uses when doing
1599+
// full multiply. Given the low/high patterns are separate, it's non-trivial
1600+
// to vary the number of uses to check this, but applying the optimization
1601+
// when there's an unrelated use when only doing the high part still results
1602+
// in less instructions and is likely profitable, so an upper bound of 2 uses
1603+
// should be fine.
1604+
if (!match(YLowXLow, m_c_Mul(m_Specific(YLow), m_Specific(XLow))) ||
1605+
YLowXLow->hasNUsesOrMore(3))
1606+
return false;
1607+
1608+
Value *X = nullptr;
1609+
// lo(x) = x & 0xffffffff
1610+
if (!match(XLow, m_c_And(m_Value(X), m_SpecificInt(0xffffffff))) ||
1611+
!XLow->hasNUses(2))
1612+
return false;
1613+
// hi(x) = x >> 32
1614+
if (!match(XHigh, m_LShr(m_Specific(X), m_SpecificInt(32))) ||
1615+
!XHigh->hasNUses(2))
1616+
return false;
1617+
1618+
// Same for Y.
1619+
Value *Y = nullptr;
1620+
if (!match(YLow, m_c_And(m_Value(Y), m_SpecificInt(0xffffffff))) ||
1621+
!YLow->hasNUses(2))
1622+
return false;
1623+
if (!match(YHigh, m_LShr(m_Specific(Y), m_SpecificInt(32))) ||
1624+
!YHigh->hasNUses(2))
1625+
return false;
1626+
1627+
IRBuilder<> Builder(&I);
1628+
Value *XExt = Builder.CreateZExt(X, Builder.getInt128Ty());
1629+
Value *YExt = Builder.CreateZExt(Y, Builder.getInt128Ty());
1630+
Value *Mul128 = Builder.CreateMul(XExt, YExt);
1631+
Value *High = Builder.CreateLShr(Mul128, 64);
1632+
Value *Res = Builder.CreateTrunc(High, Builder.getInt64Ty());
1633+
I.replaceAllUsesWith(Res);
1634+
1635+
return true;
1636+
}
1637+
1638+
/// Match another variant of high part of 128-bit multiplication.
1639+
///
1640+
/// %t0 = mul nuw i64 %y_lo, %x_lo
1641+
/// %t1 = mul nuw i64 %y_lo, %x_hi
1642+
/// %t2 = mul nuw i64 %y_hi, %x_lo
1643+
/// %t3 = mul nuw i64 %y_hi, %x_hi
1644+
/// %t0_hi = lshr i64 %t0, 32
1645+
/// %u0 = add nuw i64 %t0_hi, %t1
1646+
/// %u0_lo = and i64 %u0, 4294967295
1647+
/// %u0_hi = lshr i64 %u0, 32
1648+
/// %u1 = add nuw i64 %u0_lo, %t2
1649+
/// %u1_hi = lshr i64 %u1, 32
1650+
/// %u2 = add nuw i64 %u0_hi, %t3
1651+
/// %hw64 = add nuw i64 %u2, %u1_hi
1652+
static bool foldMul128HighVariant(Instruction &I, const DataLayout &DL,
1653+
DominatorTree &DT) {
1654+
auto *Ty = I.getType();
1655+
if (!Ty->isIntegerTy(64))
1656+
return false;
1657+
1658+
// hw64 = (hi(u0) + (hi(y) * hi(x)) + (lo(u0) + (hi(y) * lo(x)) >> 32))
1659+
Value *U0 = nullptr, *XHigh = nullptr, *YHigh = nullptr, *XLow = nullptr;
1660+
if (!match(
1661+
&I,
1662+
m_c_Add(m_OneUse(m_c_Add(
1663+
m_OneUse(m_LShr(m_Value(U0), m_SpecificInt(32))),
1664+
m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XHigh))))),
1665+
m_OneUse(m_LShr(
1666+
m_OneUse(m_c_Add(
1667+
m_OneUse(m_c_And(m_Deferred(U0),
1668+
m_SpecificInt(0xffffffff))),
1669+
m_OneUse(m_c_Mul(m_Deferred(YHigh), m_Value(XLow))))),
1670+
m_SpecificInt(32))))))
1671+
return false;
1672+
1673+
// u0 = (hi(lo(y) * lo(x)) + (lo(y) * hi(x)))
1674+
Value *YLow = nullptr;
1675+
if (!match(U0,
1676+
m_c_Add(m_OneUse(m_LShr(
1677+
m_OneUse(m_c_Mul(m_Value(YLow), m_Specific(XLow))),
1678+
m_SpecificInt(32))),
1679+
m_OneUse(m_c_Mul(m_Deferred(YLow), m_Specific(XHigh))))) ||
1680+
!U0->hasNUses(2))
1681+
return false;
1682+
1683+
Value *X = nullptr;
1684+
// lo(x) = x & 0xffffffff
1685+
if (!match(XLow, m_c_And(m_Value(X), m_SpecificInt(0xffffffff))) ||
1686+
!XLow->hasNUses(2))
1687+
return false;
1688+
// hi(x) = x >> 32
1689+
if (!match(XHigh, m_LShr(m_Specific(X), m_SpecificInt(32))) ||
1690+
!XHigh->hasNUses(2))
1691+
return false;
1692+
1693+
// Same for Y.
1694+
Value *Y = nullptr;
1695+
if (!match(YLow, m_c_And(m_Value(Y), m_SpecificInt(0xffffffff))) ||
1696+
!YLow->hasNUses(2))
1697+
return false;
1698+
if (!match(YHigh, m_LShr(m_Specific(Y), m_SpecificInt(32))) ||
1699+
!YHigh->hasNUses(2))
1700+
return false;
1701+
1702+
IRBuilder<> Builder(&I);
1703+
Value *XExt = Builder.CreateZExt(X, Builder.getInt128Ty());
1704+
Value *YExt = Builder.CreateZExt(Y, Builder.getInt128Ty());
1705+
Value *Mul128 = Builder.CreateMul(XExt, YExt);
1706+
Value *High = Builder.CreateLShr(Mul128, 64);
1707+
Value *Res = Builder.CreateTrunc(High, Builder.getInt64Ty());
1708+
I.replaceAllUsesWith(Res);
1709+
1710+
return true;
1711+
}
1712+
14601713
/// This is the entry point for folds that could be implemented in regular
14611714
/// InstCombine, but they are separated because they are not expected to
14621715
/// occur frequently and/or have more than a constant-length pattern match.
@@ -1486,6 +1739,9 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
14861739
MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
14871740
MadeChange |= foldPatternedLoads(I, DL);
14881741
MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT);
1742+
MadeChange |= foldMul128Low(I, DL, DT);
1743+
MadeChange |= foldMul128High(I, DL, DT);
1744+
MadeChange |= foldMul128HighVariant(I, DL, DT);
14891745
// NOTE: This function introduces erasing of the instruction `I`, so it
14901746
// needs to be called at the end of this sequence, otherwise we may make
14911747
// bugs.

0 commit comments

Comments
 (0)