@@ -1466,6 +1466,329 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
14661466 return false ;
14671467}
14681468
1469+ // / Match high part of long multiplication.
1470+ // /
1471+ // / Considering a multiply made up of high and low parts, we can split the
1472+ // / multiply into:
1473+ // / x * y == (xh*T + xl) * (yh*T + yl)
1474+ // / where xh == x>>32 and xl == x & 0xffffffff. T = 2^32.
1475+ // / This expands to
1476+ // / xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl
1477+ // / which can be drawn as
1478+ // / [ xh*yh ]
1479+ // / [ xh*yl ]
1480+ // / [ xl*yh ]
1481+ // / [ xl*yl ]
1482+ // / We are looking for the "high" half, which is xh*yh + xh*yl>>32 + xl*yh>>32 +
1483+ // / some carrys. The carry makes this difficult and there are multiple ways of
1484+ // / representing it. The ones we attempt to support here are:
1485+ // / Carry: xh*yh + carry + lowsum
1486+ // / carry = lowsum < xh*yl ? 0x1000000 : 0
1487+ // / lowsum = xh*yl + xl*yh + (xl*yl>>32)
1488+ // / Ladder: xh*yh + c2>>32 + c3>>32
1489+ // / c2 = xh*yl + (xl*yl>>32); c3 = c2&0xffffffff + xl*yh
1490+ // / or c2 = (xl*yh&0xffffffff) + xh*yl + (xl*yl>>32); c3 = xl*yh
1491+ // / Carry4: xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32
1492+ // / crosssum = xh*yl + xl*yh
1493+ // / carry = crosssum < xh*yl ? 0x1000000 : 0
1494+ // / Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
1495+ // / low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
1496+ // /
1497+ // / They all start by matching xh*yh + 2 or 3 other operands. The bottom of the
1498+ // / tree is xh*yh, xh*yl, xl*yh and xl*yl.
1499+ static bool foldMulHigh (Instruction &I) {
1500+ Type *Ty = I.getType ();
1501+ if (!Ty->isIntOrIntVectorTy ())
1502+ return false ;
1503+
1504+ unsigned BitWidth = Ty->getScalarSizeInBits ();
1505+ APInt LowMask = APInt::getLowBitsSet (BitWidth, BitWidth / 2 );
1506+ if (BitWidth % 2 != 0 )
1507+ return false ;
1508+
1509+ auto CreateMulHigh = [&](Value *X, Value *Y) {
1510+ IRBuilder<> Builder (&I);
1511+ Type *NTy = Ty->getWithNewBitWidth (BitWidth * 2 );
1512+ Value *XExt = Builder.CreateZExt (X, NTy);
1513+ Value *YExt = Builder.CreateZExt (Y, NTy);
1514+ Value *Mul = Builder.CreateMul (XExt, YExt, " " , /* HasNUW=*/ true );
1515+ Value *High = Builder.CreateLShr (Mul, BitWidth);
1516+ Value *Res = Builder.CreateTrunc (High, Ty, " " , /* HasNUW=*/ true );
1517+ Res->takeName (&I);
1518+ I.replaceAllUsesWith (Res);
1519+ LLVM_DEBUG (dbgs () << " Created long multiply from parts of " << *X << " and "
1520+ << *Y << " \n " );
1521+ return true ;
1522+ };
1523+
1524+ // Common check routines for X_lo*Y_lo and X_hi*Y_lo
1525+ auto CheckLoLo = [&](Value *XlYl, Value *X, Value *Y) {
1526+ return match (XlYl, m_c_Mul (m_And (m_Specific (X), m_SpecificInt (LowMask)),
1527+ m_And (m_Specific (Y), m_SpecificInt (LowMask))));
1528+ };
1529+ auto CheckHiLo = [&](Value *XhYl, Value *X, Value *Y) {
1530+ return match (XhYl,
1531+ m_c_Mul (m_LShr (m_Specific (X), m_SpecificInt (BitWidth / 2 )),
1532+ m_And (m_Specific (Y), m_SpecificInt (LowMask))));
1533+ };
1534+
1535+ auto FoldMulHighCarry = [&](Value *X, Value *Y, Instruction *Carry,
1536+ Instruction *B) {
1537+ // Looking for LowSum >> 32 and carry (select)
1538+ if (Carry->getOpcode () != Instruction::Select)
1539+ std::swap (Carry, B);
1540+
1541+ // Carry = LowSum < XhYl ? 0x100000000 : 0
1542+ Value *LowSum, *XhYl;
1543+ if (!match (Carry,
1544+ m_OneUse (m_Select (
1545+ m_OneUse (m_SpecificICmp (ICmpInst::ICMP_ULT, m_Value (LowSum),
1546+ m_Value (XhYl))),
1547+ m_SpecificInt (APInt::getOneBitSet (BitWidth, BitWidth / 2 )),
1548+ m_Zero ()))))
1549+ return false ;
1550+
1551+ // XhYl can be Xh*Yl or Xl*Yh
1552+ if (!CheckHiLo (XhYl, X, Y)) {
1553+ if (CheckHiLo (XhYl, Y, X))
1554+ std::swap (X, Y);
1555+ else
1556+ return false ;
1557+ }
1558+ if (XhYl->hasNUsesOrMore (3 ))
1559+ return false ;
1560+
1561+ // B = LowSum >> 32
1562+ if (!match (B, m_OneUse (m_LShr (m_Specific (LowSum),
1563+ m_SpecificInt (BitWidth / 2 )))) ||
1564+ LowSum->hasNUsesOrMore (3 ))
1565+ return false ;
1566+
1567+ // LowSum = XhYl + XlYh + XlYl>>32
1568+ Value *XlYh, *XlYl;
1569+ auto XlYlHi = m_LShr (m_Value (XlYl), m_SpecificInt (BitWidth / 2 ));
1570+ if (!match (LowSum,
1571+ m_c_Add (m_Specific (XhYl),
1572+ m_OneUse (m_c_Add (m_OneUse (m_Value (XlYh)), XlYlHi)))) &&
1573+ !match (LowSum, m_c_Add (m_OneUse (m_Value (XlYh)),
1574+ m_OneUse (m_c_Add (m_Specific (XhYl), XlYlHi)))) &&
1575+ !match (LowSum,
1576+ m_c_Add (XlYlHi, m_OneUse (m_c_Add (m_Specific (XhYl),
1577+ m_OneUse (m_Value (XlYh)))))))
1578+ return false ;
1579+
1580+ // Check XlYl and XlYh
1581+ if (!CheckLoLo (XlYl, X, Y))
1582+ return false ;
1583+ if (!CheckHiLo (XlYh, Y, X))
1584+ return false ;
1585+
1586+ return CreateMulHigh (X, Y);
1587+ };
1588+
1589+ auto FoldMulHighLadder = [&](Value *X, Value *Y, Instruction *A,
1590+ Instruction *B) {
1591+ // xh*yh + c2>>32 + c3>>32
1592+ // c2 = xh*yl + (xl*yl>>32); c3 = c2&0xffffffff + xl*yh
1593+ // or c2 = (xl*yh&0xffffffff) + xh*yl + (xl*yl>>32); c3 = xh*yl
1594+ Value *XlYh, *XhYl, *XlYl, *C2, *C3;
1595+ // Strip off the two expected shifts.
1596+ if (!match (A, m_LShr (m_Value (C2), m_SpecificInt (BitWidth / 2 ))) ||
1597+ !match (B, m_LShr (m_Value (C3), m_SpecificInt (BitWidth / 2 ))))
1598+ return false ;
1599+
1600+ if (match (C3, m_c_Add (m_Add (m_Value (), m_Value ()), m_Value ())))
1601+ std::swap (C2, C3);
1602+ // Try to match c2 = (xl*yh&0xffffffff) + xh*yl + (xl*yl>>32)
1603+ if (match (C2,
1604+ m_c_Add (m_c_Add (m_And (m_Specific (C3), m_SpecificInt (LowMask)),
1605+ m_Value (XlYh)),
1606+ m_LShr (m_Value (XlYl), m_SpecificInt (BitWidth / 2 )))) ||
1607+ match (C2, m_c_Add (m_c_Add (m_And (m_Specific (C3), m_SpecificInt (LowMask)),
1608+ m_LShr (m_Value (XlYl),
1609+ m_SpecificInt (BitWidth / 2 ))),
1610+ m_Value (XlYh))) ||
1611+ match (C2, m_c_Add (m_c_Add (m_LShr (m_Value (XlYl),
1612+ m_SpecificInt (BitWidth / 2 )),
1613+ m_Value (XlYh)),
1614+ m_And (m_Specific (C3), m_SpecificInt (LowMask))))) {
1615+ XhYl = C3;
1616+ } else {
1617+ // Match c3 = c2&0xffffffff + xl*yh
1618+ if (!match (C3, m_c_Add (m_And (m_Specific (C2), m_SpecificInt (LowMask)),
1619+ m_Value (XlYh))))
1620+ std::swap (C2, C3);
1621+ if (!match (C3, m_c_Add (m_OneUse (
1622+ m_And (m_Specific (C2), m_SpecificInt (LowMask))),
1623+ m_Value (XlYh))) ||
1624+ !C3->hasOneUse () || C2->hasNUsesOrMore (3 ))
1625+ return false ;
1626+
1627+ // Match c2 = xh*yl + (xl*yl >> 32)
1628+ if (!match (C2, m_c_Add (m_LShr (m_Value (XlYl), m_SpecificInt (BitWidth / 2 )),
1629+ m_Value (XhYl))))
1630+ return false ;
1631+ }
1632+
1633+ // Match XhYl and XlYh - they can appear either way around.
1634+ if (!CheckHiLo (XlYh, Y, X))
1635+ std::swap (XlYh, XhYl);
1636+ if (!CheckHiLo (XlYh, Y, X))
1637+ return false ;
1638+ if (!CheckHiLo (XhYl, X, Y))
1639+ return false ;
1640+ if (!CheckLoLo (XlYl, X, Y))
1641+ return false ;
1642+
1643+ return CreateMulHigh (X, Y);
1644+ };
1645+
1646+ auto FoldMulHighLadder4 = [&](Value *X, Value *Y, Instruction *A,
1647+ Instruction *B, Instruction *C) {
1648+ // / Ladder4: xh*yh + (xl*yh)>>32 + (xh+yl)>>32 + low>>32;
1649+ // / low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
1650+
1651+ // Find A = Low >> 32 and B/C = XhYl>>32, XlYh>>32.
1652+ auto ShiftAdd =
1653+ m_LShr (m_Add (m_Value (), m_Value ()), m_SpecificInt (BitWidth / 2 ));
1654+ if (!match (A, ShiftAdd))
1655+ std::swap (A, B);
1656+ if (!match (A, ShiftAdd))
1657+ std::swap (A, C);
1658+ Value *Low;
1659+ if (!match (A, m_LShr (m_OneUse (m_Value (Low)), m_SpecificInt (BitWidth / 2 ))))
1660+ return false ;
1661+
1662+ // Match B == XhYl>>32 and C == XlYh>>32
1663+ Value *XhYl, *XlYh;
1664+ if (!match (B, m_LShr (m_Value (XhYl), m_SpecificInt (BitWidth / 2 ))) ||
1665+ !match (C, m_LShr (m_Value (XlYh), m_SpecificInt (BitWidth / 2 ))))
1666+ return false ;
1667+ if (!CheckHiLo (XhYl, X, Y))
1668+ std::swap (XhYl, XlYh);
1669+ if (!CheckHiLo (XhYl, X, Y) || XhYl->hasNUsesOrMore (3 ))
1670+ return false ;
1671+ if (!CheckHiLo (XlYh, Y, X) || XlYh->hasNUsesOrMore (3 ))
1672+ return false ;
1673+
1674+ // Match Low as XlYl>>32 + XhYl&0xffffffff + XlYh&0xffffffff
1675+ Value *XlYl;
1676+ if (!match (
1677+ Low,
1678+ m_c_Add (
1679+ m_OneUse (m_c_Add (
1680+ m_OneUse (m_And (m_Specific (XhYl), m_SpecificInt (LowMask))),
1681+ m_OneUse (m_And (m_Specific (XlYh), m_SpecificInt (LowMask))))),
1682+ m_OneUse (
1683+ m_LShr (m_Value (XlYl), m_SpecificInt (BitWidth / 2 ))))) &&
1684+ !match (
1685+ Low,
1686+ m_c_Add (
1687+ m_OneUse (m_c_Add (
1688+ m_OneUse (m_And (m_Specific (XhYl), m_SpecificInt (LowMask))),
1689+ m_OneUse (
1690+ m_LShr (m_Value (XlYl), m_SpecificInt (BitWidth / 2 ))))),
1691+ m_OneUse (m_And (m_Specific (XlYh), m_SpecificInt (LowMask))))) &&
1692+ !match (
1693+ Low,
1694+ m_c_Add (
1695+ m_OneUse (m_c_Add (
1696+ m_OneUse (m_And (m_Specific (XlYh), m_SpecificInt (LowMask))),
1697+ m_OneUse (
1698+ m_LShr (m_Value (XlYl), m_SpecificInt (BitWidth / 2 ))))),
1699+ m_OneUse (m_And (m_Specific (XhYl), m_SpecificInt (LowMask))))))
1700+ return false ;
1701+ if (!CheckLoLo (XlYl, X, Y))
1702+ return false ;
1703+
1704+ return CreateMulHigh (X, Y);
1705+ };
1706+
1707+ auto FoldMulHighCarry4 = [&](Value *X, Value *Y, Instruction *Carry,
1708+ Instruction *B, Instruction *C) {
1709+ // xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32
1710+ // crosssum = xh*yl+xl*yh
1711+ // carry = crosssum < xh*yl ? 0x1000000 : 0
1712+ if (Carry->getOpcode () != Instruction::Select)
1713+ std::swap (Carry, B);
1714+ if (Carry->getOpcode () != Instruction::Select)
1715+ std::swap (Carry, C);
1716+
1717+ // Carry = CrossSum < XhYl ? 0x100000000 : 0
1718+ Value *CrossSum, *XhYl;
1719+ if (!match (Carry,
1720+ m_OneUse (m_Select (
1721+ m_OneUse (m_SpecificICmp (ICmpInst::ICMP_ULT,
1722+ m_Value (CrossSum), m_Value (XhYl))),
1723+ m_SpecificInt (APInt::getOneBitSet (BitWidth, BitWidth / 2 )),
1724+ m_Zero ()))))
1725+ return false ;
1726+
1727+ if (!match (B, m_LShr (m_Specific (CrossSum), m_SpecificInt (BitWidth / 2 ))))
1728+ std::swap (B, C);
1729+ if (!match (B, m_LShr (m_Specific (CrossSum), m_SpecificInt (BitWidth / 2 ))))
1730+ return false ;
1731+
1732+ Value *XlYl, *LowAccum;
1733+ if (!match (C, m_LShr (m_Value (LowAccum), m_SpecificInt (BitWidth / 2 ))) ||
1734+ !match (LowAccum, m_c_Add (m_OneUse (m_LShr (m_Value (XlYl),
1735+ m_SpecificInt (BitWidth / 2 ))),
1736+ m_OneUse (m_And (m_Specific (CrossSum),
1737+ m_SpecificInt (LowMask))))) ||
1738+ LowAccum->hasNUsesOrMore (3 ))
1739+ return false ;
1740+ if (!CheckLoLo (XlYl, X, Y))
1741+ return false ;
1742+
1743+ if (!CheckHiLo (XhYl, X, Y))
1744+ std::swap (X, Y);
1745+ if (!CheckHiLo (XhYl, X, Y))
1746+ return false ;
1747+ Value *XlYh;
1748+ if (!match (CrossSum, m_c_Add (m_Specific (XhYl), m_OneUse (m_Value (XlYh)))) ||
1749+ !CheckHiLo (XlYh, Y, X) || CrossSum->hasNUsesOrMore (4 ) ||
1750+ XhYl->hasNUsesOrMore (3 ))
1751+ return false ;
1752+
1753+ return CreateMulHigh (X, Y);
1754+ };
1755+
1756+ // X and Y are the two inputs, A, B and C are other parts of the pattern
1757+ // (crosssum>>32, carry, etc).
1758+ Value *X, *Y;
1759+ Instruction *A, *B, *C;
1760+ auto HiHi = m_OneUse (m_Mul (m_LShr (m_Value (X), m_SpecificInt (BitWidth / 2 )),
1761+ m_LShr (m_Value (Y), m_SpecificInt (BitWidth / 2 ))));
1762+ if ((match (&I, m_c_Add (HiHi, m_OneUse (m_Add (m_Instruction (A),
1763+ m_Instruction (B))))) ||
1764+ match (&I, m_c_Add (m_Instruction (A),
1765+ m_OneUse (m_c_Add (HiHi, m_Instruction (B)))))) &&
1766+ A->hasOneUse () && B->hasOneUse ())
1767+ if (FoldMulHighCarry (X, Y, A, B) || FoldMulHighLadder (X, Y, A, B))
1768+ return true ;
1769+
1770+ if ((match (&I, m_c_Add (HiHi, m_OneUse (m_c_Add (
1771+ m_Instruction (A),
1772+ m_OneUse (m_Add (m_Instruction (B),
1773+ m_Instruction (C))))))) ||
1774+ match (&I, m_c_Add (m_Instruction (A),
1775+ m_OneUse (m_c_Add (
1776+ HiHi, m_OneUse (m_Add (m_Instruction (B),
1777+ m_Instruction (C))))))) ||
1778+ match (&I, m_c_Add (m_Instruction (A),
1779+ m_OneUse (m_c_Add (
1780+ m_Instruction (B),
1781+ m_OneUse (m_c_Add (HiHi, m_Instruction (C))))))) ||
1782+ match (&I,
1783+ m_c_Add (m_OneUse (m_c_Add (HiHi, m_Instruction (A))),
1784+ m_OneUse (m_Add (m_Instruction (B), m_Instruction (C)))))) &&
1785+ A->hasOneUse () && B->hasOneUse () && C->hasOneUse ())
1786+ return FoldMulHighCarry4 (X, Y, A, B, C) ||
1787+ FoldMulHighLadder4 (X, Y, A, B, C);
1788+
1789+ return false ;
1790+ }
1791+
14691792// / This is the entry point for folds that could be implemented in regular
14701793// / InstCombine, but they are separated because they are not expected to
14711794// / occur frequently and/or have more than a constant-length pattern match.
@@ -1495,6 +1818,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
14951818 MadeChange |= foldConsecutiveLoads (I, DL, TTI, AA, DT);
14961819 MadeChange |= foldPatternedLoads (I, DL);
14971820 MadeChange |= foldICmpOrChain (I, DL, TTI, AA, DT);
1821+ MadeChange |= foldMulHigh (I);
14981822 // NOTE: This function introduces erasing of the instruction `I`, so it
14991823 // needs to be called at the end of this sequence, otherwise we may make
15001824 // bugs.
0 commit comments