Commit 0cc5ee4

[AggressiveInstCombine] Fold i64 x i64 -> i128 multiply-by-parts
This patch adds patterns to recognize a full i64 x i64 -> i128 multiplication implemented by 4 x i32 parts, folding it to a full 128-bit multiply. The low and high parts are matched as independent patterns. There is also an additional pattern for the high part; both high-part patterns have been seen in real code, and there is one more I'm aware of, but I thought I'd share a patch first to see what people think before handling further cases.

On AArch64 the mul and umulh instructions can be used to compute the low and high parts efficiently, and I believe X86 can do the i128 multiply in one instruction (returning both halves). So this pattern seems relatively common, and the fold could be a useful optimization for several targets.
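As an illustration (a sketch for this description, not code from the patch; the helper names are invented), the low-part pattern corresponds to multiply-by-parts source code roughly like the following, and the fold reduces it to a single widened multiply. The two high-part shapes are sketched after their matchers in the diff below.

#include <cstdint>

// Hand-written 64 x 64 multiply split into 32-bit halves; the IR produced
// for mul_lo is roughly the shape that foldMul128Low matches.
static uint64_t mul_lo(uint64_t x, uint64_t y) {
  uint64_t x_lo = x & 0xffffffff, x_hi = x >> 32;
  uint64_t y_lo = y & 0xffffffff, y_hi = y >> 32;
  uint64_t lo_lo = y_lo * x_lo;                 // lo(y) * lo(x)
  uint64_t cross = y_hi * x_lo + y_lo * x_hi;   // cross_sum
  uint64_t low_accum = (cross & 0xffffffff) + (lo_lo >> 32);
  return (low_accum << 32) | (lo_lo & 0xffffffff);
}

// After the fold this becomes a plain widened multiply (__uint128_t is the
// Clang/GCC extension type):
static uint64_t mul_lo_folded(uint64_t x, uint64_t y) {
  return (uint64_t)((__uint128_t)x * (__uint128_t)y);
}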
1 parent 34d4f0c commit 0cc5ee4

File tree

2 files changed: +2827 -0 lines changed


llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp

Lines changed: 256 additions & 0 deletions
@@ -1428,6 +1428,259 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
  return false;
}

/// Match low part of 128-bit multiplication.
static bool foldMul128Low(Instruction &I, const DataLayout &DL,
                          DominatorTree &DT) {
  auto *Ty = I.getType();
  if (!Ty->isIntegerTy(64))
    return false;

  // (low_accum << 32) | lo(lo(y) * lo(x))
  Value *LowAccum = nullptr, *YLowXLow = nullptr;
  if (!match(&I, m_c_DisjointOr(
                     m_OneUse(m_Shl(m_Value(LowAccum), m_SpecificInt(32))),
                     m_OneUse(
                         m_And(m_Value(YLowXLow), m_SpecificInt(0xffffffff))))))
    return false;

  // lo(cross_sum) + hi(lo(y) * lo(x))
  Value *CrossSum = nullptr;
  if (!match(
          LowAccum,
          m_c_Add(m_OneUse(m_And(m_Value(CrossSum), m_SpecificInt(0xffffffff))),
                  m_OneUse(m_LShr(m_Specific(YLowXLow), m_SpecificInt(32))))) ||
      LowAccum->hasNUsesOrMore(3))
    return false;

  // (hi(y) * lo(x)) + (lo(y) * hi(x))
  Value *YHigh = nullptr, *XLow = nullptr, *YLowXHigh = nullptr;
  if (!match(CrossSum, m_c_Add(m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XLow))),
                               m_Value(YLowXHigh))) ||
      CrossSum->hasNUsesOrMore(4))
    return false;

  // lo(y) * lo(x)
  Value *YLow = nullptr;
  if (!match(YLowXLow, m_c_Mul(m_Value(YLow), m_Specific(XLow))) ||
      YLowXLow->hasNUsesOrMore(3))
    return false;

  // lo(y) * hi(x)
  Value *XHigh = nullptr;
  if (!match(YLowXHigh, m_c_Mul(m_Specific(YLow), m_Value(XHigh))) ||
      !YLowXHigh->hasNUses(2))
    return false;

  Value *X = nullptr;
  // lo(x) = x & 0xffffffff
  if (!match(XLow, m_c_And(m_Value(X), m_SpecificInt(0xffffffff))) ||
      !XLow->hasNUses(2))
    return false;
  // hi(x) = x >> 32
  if (!match(XHigh, m_LShr(m_Specific(X), m_SpecificInt(32))) ||
      !XHigh->hasNUses(2))
    return false;

  // Same for Y.
  Value *Y = nullptr;
  if (!match(YLow, m_c_And(m_Value(Y), m_SpecificInt(0xffffffff))) ||
      !YLow->hasNUses(2))
    return false;
  if (!match(YHigh, m_LShr(m_Specific(Y), m_SpecificInt(32))) ||
      !YHigh->hasNUses(2))
    return false;

  IRBuilder<> Builder(&I);
  Value *XExt = Builder.CreateZExt(X, Builder.getInt128Ty());
  Value *YExt = Builder.CreateZExt(Y, Builder.getInt128Ty());
  Value *Mul128 = Builder.CreateMul(XExt, YExt);
  Value *Res = Builder.CreateTrunc(Mul128, Builder.getInt64Ty());
  I.replaceAllUsesWith(Res);

  return true;
}

/// Match high part of 128-bit multiplication.
static bool foldMul128High(Instruction &I, const DataLayout &DL,
                           DominatorTree &DT) {
  auto *Ty = I.getType();
  if (!Ty->isIntegerTy(64))
    return false;

  // intermediate_plus_carry + hi(low_accum)
  Value *IntermediatePlusCarry = nullptr, *LowAccum = nullptr;
  if (!match(&I,
             m_c_Add(m_OneUse(m_Value(IntermediatePlusCarry)),
                     m_OneUse(m_LShr(m_Value(LowAccum), m_SpecificInt(32))))))
    return false;

  // match:
  //   (((hi(y) * hi(x)) + carry) + hi(cross_sum))
  // or:
  //   ((hi(cross_sum) + (hi(y) * hi(x))) + carry)
  CmpPredicate Pred;
  Value *CrossSum = nullptr, *XHigh = nullptr, *YHigh = nullptr,
        *Carry = nullptr;
  if (!match(IntermediatePlusCarry,
             m_c_Add(m_c_Add(m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XHigh))),
                             m_Value(Carry)),
                     m_OneUse(m_LShr(m_Value(CrossSum), m_SpecificInt(32))))) &&
      !match(IntermediatePlusCarry,
             m_c_Add(m_OneUse(m_c_Add(
                         m_OneUse(m_LShr(m_Value(CrossSum), m_SpecificInt(32))),
                         m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XHigh))))),
                     m_Value(Carry))))
    return false;

  // (select (icmp ult cross_sum, (lo(y) * hi(x))), (1 << 32), 0)
  Value *YLowXHigh = nullptr;
  if (!match(Carry,
             m_OneUse(m_Select(m_OneUse(m_ICmp(Pred, m_Specific(CrossSum),
                                               m_Value(YLowXHigh))),
                               m_SpecificInt(4294967296), m_SpecificInt(0)))) ||
      Pred != ICmpInst::ICMP_ULT)
    return false;

  // (hi(y) * lo(x)) + (lo(y) * hi(x))
  Value *XLow = nullptr;
  if (!match(CrossSum,
             m_c_Add(m_OneUse(m_c_Mul(m_Specific(YHigh), m_Value(XLow))),
                     m_Specific(YLowXHigh))) ||
      CrossSum->hasNUsesOrMore(4))
    return false;

  // lo(y) * hi(x)
  Value *YLow = nullptr;
  if (!match(YLowXHigh, m_c_Mul(m_Value(YLow), m_Specific(XHigh))) ||
      !YLowXHigh->hasNUses(2))
    return false;

  // lo(cross_sum) + hi(lo(y) * lo(x))
  Value *YLowXLow = nullptr;
  if (!match(LowAccum,
             m_c_Add(m_OneUse(m_c_And(m_Specific(CrossSum),
                                      m_SpecificInt(0xffffffff))),
                     m_OneUse(m_LShr(m_Value(YLowXLow), m_SpecificInt(32))))) ||
      LowAccum->hasNUsesOrMore(3))
    return false;

  // lo(y) * lo(x)
  //
  // When only doing the high part there's a single use, and 2 uses when doing
  // the full multiply. Given the low/high patterns are separate, it's
  // non-trivial to vary the number of uses to check this, but applying the
  // optimization when there's an unrelated use when only doing the high part
  // still results in fewer instructions and is likely profitable, so an upper
  // bound of 2 uses should be fine.
  if (!match(YLowXLow, m_c_Mul(m_Specific(YLow), m_Specific(XLow))) ||
      YLowXLow->hasNUsesOrMore(3))
    return false;

  Value *X = nullptr;
  // lo(x) = x & 0xffffffff
  if (!match(XLow, m_c_And(m_Value(X), m_SpecificInt(0xffffffff))) ||
      !XLow->hasNUses(2))
    return false;
  // hi(x) = x >> 32
  if (!match(XHigh, m_LShr(m_Specific(X), m_SpecificInt(32))) ||
      !XHigh->hasNUses(2))
    return false;

  // Same for Y.
  Value *Y = nullptr;
  if (!match(YLow, m_c_And(m_Value(Y), m_SpecificInt(0xffffffff))) ||
      !YLow->hasNUses(2))
    return false;
  if (!match(YHigh, m_LShr(m_Specific(Y), m_SpecificInt(32))) ||
      !YHigh->hasNUses(2))
    return false;

  IRBuilder<> Builder(&I);
  Value *XExt = Builder.CreateZExt(X, Builder.getInt128Ty());
  Value *YExt = Builder.CreateZExt(Y, Builder.getInt128Ty());
  Value *Mul128 = Builder.CreateMul(XExt, YExt);
  Value *High = Builder.CreateLShr(Mul128, 64);
  Value *Res = Builder.CreateTrunc(High, Builder.getInt64Ty());
  I.replaceAllUsesWith(Res);

  return true;
}
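For reference, a rough source-level sketch of the carry-select shape matched by foldMul128High (illustrative only, not part of the patch; the variable names are invented):

#include <cstdint>

// High 64 bits of a 64 x 64 multiply, using a compare-based carry for the
// possible wrap of the cross sum.
static uint64_t mul_hi(uint64_t x, uint64_t y) {
  uint64_t x_lo = x & 0xffffffff, x_hi = x >> 32;
  uint64_t y_lo = y & 0xffffffff, y_hi = y >> 32;
  uint64_t lo_lo = y_lo * x_lo;               // lo(y) * lo(x)
  uint64_t y_lo_x_hi = y_lo * x_hi;           // lo(y) * hi(x)
  uint64_t cross = y_hi * x_lo + y_lo_x_hi;   // cross_sum, may wrap past 2^64
  // If the cross sum wrapped, re-add the lost 2^32 contribution.
  uint64_t carry = cross < y_lo_x_hi ? (1ULL << 32) : 0;
  uint64_t low_accum = (cross & 0xffffffff) + (lo_lo >> 32);
  return (y_hi * x_hi + carry + (cross >> 32)) + (low_accum >> 32);
}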

/// Match another variant of high part of 128-bit multiplication.
///
/// %t0 = mul nuw i64 %y_lo, %x_lo
/// %t1 = mul nuw i64 %y_lo, %x_hi
/// %t2 = mul nuw i64 %y_hi, %x_lo
/// %t3 = mul nuw i64 %y_hi, %x_hi
/// %t0_hi = lshr i64 %t0, 32
/// %u0 = add nuw i64 %t0_hi, %t1
/// %u0_lo = and i64 %u0, 4294967295
/// %u0_hi = lshr i64 %u0, 32
/// %u1 = add nuw i64 %u0_lo, %t2
/// %u1_hi = lshr i64 %u1, 32
/// %u2 = add nuw i64 %u0_hi, %t3
/// %hw64 = add nuw i64 %u2, %u1_hi
static bool foldMul128HighVariant(Instruction &I, const DataLayout &DL,
                                  DominatorTree &DT) {
  auto *Ty = I.getType();
  if (!Ty->isIntegerTy(64))
    return false;

  // hw64 = (hi(u0) + (hi(y) * hi(x)) + (lo(u0) + (hi(y) * lo(x)) >> 32))
  Value *U0 = nullptr, *XHigh = nullptr, *YHigh = nullptr, *XLow = nullptr;
  if (!match(
          &I,
          m_c_Add(m_OneUse(m_c_Add(
                      m_OneUse(m_LShr(m_Value(U0), m_SpecificInt(32))),
                      m_OneUse(m_c_Mul(m_Value(YHigh), m_Value(XHigh))))),
                  m_OneUse(m_LShr(
                      m_OneUse(m_c_Add(
                          m_OneUse(m_c_And(m_Deferred(U0),
                                           m_SpecificInt(0xffffffff))),
                          m_OneUse(m_c_Mul(m_Deferred(YHigh), m_Value(XLow))))),
                      m_SpecificInt(32))))))
    return false;

  // u0 = (hi(lo(y) * lo(x)) + (lo(y) * hi(x)))
  Value *YLow = nullptr;
  if (!match(U0,
             m_c_Add(m_OneUse(m_LShr(
                         m_OneUse(m_c_Mul(m_Value(YLow), m_Specific(XLow))),
                         m_SpecificInt(32))),
                     m_OneUse(m_c_Mul(m_Deferred(YLow), m_Specific(XHigh))))) ||
      !U0->hasNUses(2))
    return false;

  Value *X = nullptr;
  // lo(x) = x & 0xffffffff
  if (!match(XLow, m_c_And(m_Value(X), m_SpecificInt(0xffffffff))) ||
      !XLow->hasNUses(2))
    return false;
  // hi(x) = x >> 32
  if (!match(XHigh, m_LShr(m_Specific(X), m_SpecificInt(32))) ||
      !XHigh->hasNUses(2))
    return false;

  // Same for Y.
  Value *Y = nullptr;
  if (!match(YLow, m_c_And(m_Value(Y), m_SpecificInt(0xffffffff))) ||
      !YLow->hasNUses(2))
    return false;
  if (!match(YHigh, m_LShr(m_Specific(Y), m_SpecificInt(32))) ||
      !YHigh->hasNUses(2))
    return false;

  IRBuilder<> Builder(&I);
  Value *XExt = Builder.CreateZExt(X, Builder.getInt128Ty());
  Value *YExt = Builder.CreateZExt(Y, Builder.getInt128Ty());
  Value *Mul128 = Builder.CreateMul(XExt, YExt);
  Value *High = Builder.CreateLShr(Mul128, 64);
  Value *Res = Builder.CreateTrunc(High, Builder.getInt64Ty());
  I.replaceAllUsesWith(Res);

  return true;
}
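The IR in the comment above corresponds to source roughly like this sketch (illustrative only, not part of the patch); this variant keeps every partial sum below 2^64, so no explicit carry is needed:

#include <cstdint>

// Carry-free high-half multiply, the shape matched by foldMul128HighVariant.
static uint64_t mul_hi_variant(uint64_t x, uint64_t y) {
  uint64_t x_lo = x & 0xffffffff, x_hi = x >> 32;
  uint64_t y_lo = y & 0xffffffff, y_hi = y >> 32;
  uint64_t t0 = y_lo * x_lo;
  uint64_t t1 = y_lo * x_hi;
  uint64_t t2 = y_hi * x_lo;
  uint64_t t3 = y_hi * x_hi;
  uint64_t u0 = (t0 >> 32) + t1;
  uint64_t u1 = (u0 & 0xffffffff) + t2;
  uint64_t u2 = (u0 >> 32) + t3;
  return u2 + (u1 >> 32);   // %hw64
}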

/// This is the entry point for folds that could be implemented in regular
/// InstCombine, but they are separated because they are not expected to
/// occur frequently and/or have more than a constant-length pattern match.
@@ -1457,6 +1710,9 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
      MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
      MadeChange |= foldPatternedLoads(I, DL);
      MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT);
      MadeChange |= foldMul128Low(I, DL, DT);
      MadeChange |= foldMul128High(I, DL, DT);
      MadeChange |= foldMul128HighVariant(I, DL, DT);
      // NOTE: This function introduces erasing of the instruction `I`, so it
      // needs to be called at the end of this sequence, otherwise we may make
      // bugs.
