@@ -1466,6 +1466,306 @@ static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
14661466 return false ;
14671467}
14681468
1469+ // / Match high part of long multiplication.
1470+ // /
1471+ // / Considering a multiply made up of high and low parts, we can split the
1472+ // / multiply into:
1473+ // / x * y == (xh*T + xl) * (yh*T + yl)
1474+ // / where xh == x>>32 and xl == x & 0xffffffff. T = 2^32.
1475+ // / This expands to
1476+ // / xh*yh*T*T + xh*yl*T + xl*yh*T + xl*yl
1477+ // / which can be drawn as
1478+ // / [ xh*yh ]
1479+ // / [ xh*yl ]
1480+ // / [ xl*yh ]
1481+ // / [ xl*yl ]
1482+ // / We are looking for the "high" half, which is xh*yh + xh*yl>>32 + xl*yh>>32 +
1483+ // / some carrys. The carry makes this difficult and there are multiple ways of
1484+ // / representing it. The ones we attempt to support here are:
1485+ // / Carry: xh*yh + carry + lowsum
1486+ // / carry = lowsum < xh*yl ? 0x1000000 : 0
1487+ // / lowsum = xh*yl + xl*yh + (xl*yl>>32)
1488+ // / Ladder: xh*yh + c2>>32 + c3>>32
1489+ // / c2 = xh*yl + (xl*yl>>32); c3 = c2&0xffffffff + xl*yh
1490+ // / Carry4: xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32
1491+ // / crosssum = xh*yl + xl*yh
1492+ // / carry = crosssum < xh*yl ? 0x1000000 : 0
1493+ // / Ladder4: xh*yh + (xl*yh)>>32 + (xh*yl)>>32 + low>>32;
1494+ // / low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
1495+ // /
1496+ // / They all start by matching xh*yh + 2 or 3 other operands. The bottom of the
1497+ // / tree is xh*yh, xh*yl, xl*yh and xl*yl.
1498+ static bool foldMulHigh (Instruction &I) {
1499+ Type *Ty = I.getType ();
1500+ if (!Ty->isIntOrIntVectorTy ())
1501+ return false ;
1502+
1503+ unsigned BW = Ty->getScalarSizeInBits ();
1504+ APInt LowMask = APInt::getLowBitsSet (BW, BW / 2 );
1505+ if (BW % 2 != 0 )
1506+ return false ;
1507+
1508+ auto CreateMulHigh = [&](Value *X, Value *Y) {
1509+ IRBuilder<> Builder (&I);
1510+ Type *NTy = Ty->getWithNewBitWidth (BW * 2 );
1511+ Value *XExt = Builder.CreateZExt (X, NTy);
1512+ Value *YExt = Builder.CreateZExt (Y, NTy);
1513+ Value *Mul = Builder.CreateMul (XExt, YExt);
1514+ Value *High = Builder.CreateLShr (Mul, BW);
1515+ Value *Res = Builder.CreateTrunc (High, Ty);
1516+ I.replaceAllUsesWith (Res);
1517+ LLVM_DEBUG (dbgs () << " Created long multiply from parts of " << *X << " and "
1518+ << *Y << " \n " );
1519+ return true ;
1520+ };
1521+
1522+ // Common check routines for X_lo*Y_lo and X_hi*Y_lo
1523+ auto CheckLoLo = [&](Value *XlYl, Value *X, Value *Y) {
1524+ return match (XlYl, m_c_Mul (m_And (m_Specific (X), m_SpecificInt (LowMask)),
1525+ m_And (m_Specific (Y), m_SpecificInt (LowMask))));
1526+ };
1527+ auto CheckHiLo = [&](Value *XhYl, Value *X, Value *Y) {
1528+ return match (XhYl, m_c_Mul (m_LShr (m_Specific (X), m_SpecificInt (BW / 2 )),
1529+ m_And (m_Specific (Y), m_SpecificInt (LowMask))));
1530+ };
1531+
1532+ auto foldMulHighCarry = [&](Value *X, Value *Y, Instruction *Carry,
1533+ Instruction *B) {
1534+ // Looking for LowSum >> 32 and carry (select)
1535+ if (Carry->getOpcode () != Instruction::Select)
1536+ std::swap (Carry, B);
1537+
1538+ // Carry = LowSum < XhYl ? 0x100000000 : 0
1539+ CmpPredicate Pred;
1540+ Value *LowSum, *XhYl;
1541+ if (!match (Carry,
1542+ m_OneUse (m_Select (
1543+ m_OneUse (m_ICmp (Pred, m_Value (LowSum), m_Value (XhYl))),
1544+ m_SpecificInt (APInt (BW, 1 ) << BW / 2 ), m_SpecificInt (0 )))) ||
1545+ Pred != ICmpInst::ICMP_ULT)
1546+ return false ;
1547+
1548+ // XhYl can be Xh*Yl or Xl*Yh
1549+ if (!CheckHiLo (XhYl, X, Y)) {
1550+ if (CheckHiLo (XhYl, Y, X))
1551+ std::swap (X, Y);
1552+ else
1553+ return false ;
1554+ }
1555+ if (XhYl->hasNUsesOrMore (3 ))
1556+ return false ;
1557+
1558+ // B = LowSum >> 16
1559+ if (!match (B,
1560+ m_OneUse (m_LShr (m_Specific (LowSum), m_SpecificInt (BW / 2 )))) ||
1561+ LowSum->hasNUsesOrMore (3 ))
1562+ return false ;
1563+
1564+ // LowSum = XhYl + XlYh + XlYl>>32
1565+ Value *XlYh, *XlYl;
1566+ auto XlYlHi = m_LShr (m_Value (XlYl), m_SpecificInt (BW / 2 ));
1567+ if (!match (LowSum,
1568+ m_c_Add (m_Specific (XhYl),
1569+ m_OneUse (m_c_Add (m_OneUse (m_Value (XlYh)), XlYlHi)))) &&
1570+ !match (LowSum, m_c_Add (m_OneUse (m_Value (XlYh)),
1571+ m_OneUse (m_c_Add (m_Specific (XhYl), XlYlHi)))) &&
1572+ !match (LowSum,
1573+ m_c_Add (XlYlHi, m_OneUse (m_c_Add (m_Specific (XhYl),
1574+ m_OneUse (m_Value (XlYh)))))))
1575+ return false ;
1576+
1577+ // Check XlYl and XlYh
1578+ if (!CheckLoLo (XlYl, X, Y))
1579+ return false ;
1580+ if (!CheckHiLo (XlYh, Y, X))
1581+ return false ;
1582+
1583+ return CreateMulHigh (X, Y);
1584+ };
1585+
1586+ auto foldMulHighLadder = [&](Value *X, Value *Y, Instruction *A,
1587+ Instruction *B) {
1588+ // xh*yh + c2>>32 + c3>>32
1589+ // c2 = xh*yl + (xl*yl >> 32); c3 = c2&0xffffffff + xl*yh
1590+ Value *XlYh, *XhYl, *C2, *C3;
1591+ // Strip off the two expected shifts.
1592+ if (!match (A, m_LShr (m_Value (C2), m_SpecificInt (BW / 2 ))) ||
1593+ !match (B, m_LShr (m_Value (C3), m_SpecificInt (BW / 2 ))))
1594+ return false ;
1595+
1596+ // Match c3 = c2&0xffffffff + xl*yh
1597+ if (!match (C3, m_c_Add (m_And (m_Specific (C2), m_SpecificInt (LowMask)),
1598+ m_Value (XhYl))))
1599+ std::swap (C2, C3);
1600+ if (!match (C3,
1601+ m_c_Add (m_OneUse (m_And (m_Specific (C2), m_SpecificInt (LowMask))),
1602+ m_Value (XhYl))) ||
1603+ !C3->hasOneUse () || C2->hasNUsesOrMore (3 ))
1604+ return false ;
1605+
1606+ // Match c2 = xh*yl + (xl*yl >> 32)
1607+ Value *XlYl;
1608+ if (!match (C2, m_c_Add (m_LShr (m_Value (XlYl), m_SpecificInt (BW / 2 )),
1609+ m_Value (XlYh))))
1610+ return false ;
1611+
1612+ // Match XhYl and XlYh - they can appear either way around.
1613+ if (!CheckHiLo (XlYh, Y, X))
1614+ std::swap (XlYh, XhYl);
1615+ if (!CheckHiLo (XlYh, Y, X))
1616+ return false ;
1617+ if (!CheckHiLo (XhYl, X, Y))
1618+ return false ;
1619+ if (!CheckLoLo (XlYl, X, Y))
1620+ return false ;
1621+
1622+ return CreateMulHigh (X, Y);
1623+ };
1624+
1625+ auto foldMulHighLadder4 = [&](Value *X, Value *Y, Instruction *A,
1626+ Instruction *B, Instruction *C) {
1627+ // / Ladder4: xh*yh + (xl*yh)>>32 + (xh+yl)>>32 + low>>32;
1628+ // / low = (xl*yl)>>32 + (xl*yh)&0xffffffff + (xh*yl)&0xffffffff
1629+
1630+ // Find A = Low >> 32 and B/C = XhYl>>32, XlYh>>32.
1631+ auto ShiftAdd = m_LShr (m_Add (m_Value (), m_Value ()), m_SpecificInt (BW / 2 ));
1632+ if (!match (A, ShiftAdd))
1633+ std::swap (A, B);
1634+ if (!match (A, ShiftAdd))
1635+ std::swap (A, C);
1636+ Value *Low;
1637+ if (!match (A, m_LShr (m_OneUse (m_Value (Low)), m_SpecificInt (BW / 2 ))))
1638+ return false ;
1639+
1640+ // Match B == XhYl>>32 and C == XlYh>>32
1641+ Value *XhYl, *XlYh;
1642+ if (!match (B, m_LShr (m_Value (XhYl), m_SpecificInt (BW / 2 ))) ||
1643+ !match (C, m_LShr (m_Value (XlYh), m_SpecificInt (BW / 2 ))))
1644+ return false ;
1645+ if (!CheckHiLo (XhYl, X, Y))
1646+ std::swap (XhYl, XlYh);
1647+ if (!CheckHiLo (XhYl, X, Y) || XhYl->hasNUsesOrMore (3 ))
1648+ return false ;
1649+ if (!CheckHiLo (XlYh, Y, X) || XlYh->hasNUsesOrMore (3 ))
1650+ return false ;
1651+
1652+ // Match Low as XlYl>>32 + XhYl&0xffffffff + XlYh&0xffffffff
1653+ Value *XlYl;
1654+ if (!match (
1655+ Low,
1656+ m_c_Add (
1657+ m_OneUse (m_c_Add (
1658+ m_OneUse (m_And (m_Specific (XhYl), m_SpecificInt (LowMask))),
1659+ m_OneUse (m_And (m_Specific (XlYh), m_SpecificInt (LowMask))))),
1660+ m_OneUse (m_LShr (m_Value (XlYl), m_SpecificInt (BW / 2 ))))) &&
1661+ !match (
1662+ Low,
1663+ m_c_Add (
1664+ m_OneUse (m_c_Add (
1665+ m_OneUse (m_And (m_Specific (XhYl), m_SpecificInt (LowMask))),
1666+ m_OneUse (m_LShr (m_Value (XlYl), m_SpecificInt (BW / 2 ))))),
1667+ m_OneUse (m_And (m_Specific (XlYh), m_SpecificInt (LowMask))))) &&
1668+ !match (
1669+ Low,
1670+ m_c_Add (
1671+ m_OneUse (m_c_Add (
1672+ m_OneUse (m_And (m_Specific (XlYh), m_SpecificInt (LowMask))),
1673+ m_OneUse (m_LShr (m_Value (XlYl), m_SpecificInt (BW / 2 ))))),
1674+ m_OneUse (m_And (m_Specific (XhYl), m_SpecificInt (LowMask))))))
1675+ return false ;
1676+ if (!CheckLoLo (XlYl, X, Y))
1677+ return false ;
1678+
1679+ return CreateMulHigh (X, Y);
1680+ };
1681+
1682+ auto foldMulHighCarry4 = [&](Value *X, Value *Y, Instruction *Carry,
1683+ Instruction *B, Instruction *C) {
1684+ // xh*yh + carry + crosssum>>32 + (xl*yl + crosssum&0xffffffff) >> 32
1685+ // crosssum = xh*yl+xl*yh
1686+ // carry = crosssum < xh*yl ? 0x1000000 : 0
1687+ if (Carry->getOpcode () != Instruction::Select)
1688+ std::swap (Carry, B);
1689+ if (Carry->getOpcode () != Instruction::Select)
1690+ std::swap (Carry, C);
1691+
1692+ // Carry = CrossSum < XhYl ? 0x100000000 : 0
1693+ CmpPredicate Pred;
1694+ Value *CrossSum, *XhYl;
1695+ if (!match (Carry,
1696+ m_OneUse (m_Select (
1697+ m_OneUse (m_ICmp (Pred, m_Value (CrossSum), m_Value (XhYl))),
1698+ m_SpecificInt (APInt (BW, 1 ) << BW / 2 ), m_SpecificInt (0 )))) ||
1699+ Pred != ICmpInst::ICMP_ULT)
1700+ return false ;
1701+
1702+ if (!match (B, m_LShr (m_Specific (CrossSum), m_SpecificInt (BW / 2 ))))
1703+ std::swap (B, C);
1704+ if (!match (B, m_LShr (m_Specific (CrossSum), m_SpecificInt (BW / 2 ))))
1705+ return false ;
1706+
1707+ Value *XlYl, *LowAccum;
1708+ if (!match (C, m_LShr (m_Value (LowAccum), m_SpecificInt (BW / 2 ))) ||
1709+ !match (LowAccum,
1710+ m_c_Add (m_OneUse (m_LShr (m_Value (XlYl), m_SpecificInt (BW / 2 ))),
1711+ m_OneUse (m_And (m_Specific (CrossSum),
1712+ m_SpecificInt (LowMask))))) ||
1713+ LowAccum->hasNUsesOrMore (3 ))
1714+ return false ;
1715+ if (!CheckLoLo (XlYl, X, Y))
1716+ return false ;
1717+
1718+ if (!CheckHiLo (XhYl, X, Y))
1719+ std::swap (X, Y);
1720+ if (!CheckHiLo (XhYl, X, Y))
1721+ return false ;
1722+ if (!match (CrossSum,
1723+ m_c_Add (m_Specific (XhYl),
1724+ m_OneUse (m_c_Mul (
1725+ m_LShr (m_Specific (Y), m_SpecificInt (BW / 2 )),
1726+ m_And (m_Specific (X), m_SpecificInt (LowMask)))))) ||
1727+ CrossSum->hasNUsesOrMore (4 ) || XhYl->hasNUsesOrMore (3 ))
1728+ return false ;
1729+
1730+ return CreateMulHigh (X, Y);
1731+ };
1732+
1733+ // X and Y are the two inputs, A, B and C are other parts of the pattern
1734+ // (crosssum>>32, carry, etc).
1735+ Value *X, *Y;
1736+ Instruction *A, *B, *C;
1737+ auto HiHi = m_OneUse (m_Mul (m_LShr (m_Value (X), m_SpecificInt (BW / 2 )),
1738+ m_LShr (m_Value (Y), m_SpecificInt (BW / 2 ))));
1739+ if ((match (&I, m_c_Add (HiHi, m_OneUse (m_Add (m_Instruction (A),
1740+ m_Instruction (B))))) ||
1741+ match (&I, m_c_Add (m_Instruction (A),
1742+ m_OneUse (m_c_Add (HiHi, m_Instruction (B)))))) &&
1743+ A->hasOneUse () && B->hasOneUse ())
1744+ if (foldMulHighCarry (X, Y, A, B) || foldMulHighLadder (X, Y, A, B))
1745+ return true ;
1746+
1747+ if ((match (&I, m_c_Add (HiHi, m_OneUse (m_c_Add (
1748+ m_Instruction (A),
1749+ m_OneUse (m_Add (m_Instruction (B),
1750+ m_Instruction (C))))))) ||
1751+ match (&I, m_c_Add (m_Instruction (A),
1752+ m_OneUse (m_c_Add (
1753+ HiHi, m_OneUse (m_Add (m_Instruction (B),
1754+ m_Instruction (C))))))) ||
1755+ match (&I, m_c_Add (m_Instruction (A),
1756+ m_OneUse (m_c_Add (
1757+ m_Instruction (B),
1758+ m_OneUse (m_c_Add (HiHi, m_Instruction (C))))))) ||
1759+ match (&I,
1760+ m_c_Add (m_OneUse (m_c_Add (HiHi, m_Instruction (A))),
1761+ m_OneUse (m_Add (m_Instruction (B), m_Instruction (C)))))) &&
1762+ A->hasOneUse () && B->hasOneUse () && C->hasOneUse ())
1763+ return foldMulHighCarry4 (X, Y, A, B, C) ||
1764+ foldMulHighLadder4 (X, Y, A, B, C);
1765+
1766+ return false ;
1767+ }
1768+
14691769// / This is the entry point for folds that could be implemented in regular
14701770// / InstCombine, but they are separated because they are not expected to
14711771// / occur frequently and/or have more than a constant-length pattern match.
@@ -1495,6 +1795,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
14951795 MadeChange |= foldConsecutiveLoads (I, DL, TTI, AA, DT);
14961796 MadeChange |= foldPatternedLoads (I, DL);
14971797 MadeChange |= foldICmpOrChain (I, DL, TTI, AA, DT);
1798+ MadeChange |= foldMulHigh (I);
14981799 // NOTE: This function introduces erasing of the instruction `I`, so it
14991800 // needs to be called at the end of this sequence, otherwise we may make
15001801 // bugs.
0 commit comments