@@ -1667,28 +1667,158 @@ void mlir::populatePolynomialApproximateErfPattern(
16671667 patterns.add <ErfPolynomialApproximation>(patterns.getContext ());
16681668}
16691669
1670+ void mlir::populateMathF32ExpansionPatterns (
1671+ RewritePatternSet &patterns,
1672+ const std::function<bool (StringRef)> &predicate) {
1673+ MLIRContext *context = patterns.getContext ();
1674+ if (predicate (" acos" )) {
1675+ patterns.add <ReuseF32Expansion<math::AcosOp>>(context);
1676+ }
1677+ if (predicate (" acosh" )) {
1678+ patterns.add <ReuseF32Expansion<math::AcoshOp>>(context);
1679+ }
1680+ if (predicate (" asin" )) {
1681+ patterns.add <ReuseF32Expansion<math::AsinOp>>(context);
1682+ }
1683+ if (predicate (" asinh" )) {
1684+ patterns.add <ReuseF32Expansion<math::AsinhOp>>(context);
1685+ }
1686+ if (predicate (" atan" )) {
1687+ patterns.add <ReuseF32Expansion<math::AtanOp>>(context);
1688+ }
1689+ if (predicate (" atan2" )) {
1690+ patterns.add <ReuseF32Expansion<math::Atan2Op>>(context);
1691+ }
1692+ if (predicate (" atanh" )) {
1693+ patterns.add <ReuseF32Expansion<math::AtanhOp>>(context);
1694+ }
1695+ if (predicate (" cbrt" )) {
1696+ patterns.add <ReuseF32Expansion<math::CbrtOp>>(context);
1697+ }
1698+ if (predicate (" cos" )) {
1699+ patterns.add <ReuseF32Expansion<math::CosOp>>(context);
1700+ }
1701+ if (predicate (" cosh" )) {
1702+ patterns.add <ReuseF32Expansion<math::CoshOp>>(context);
1703+ }
1704+ if (predicate (" erf" )) {
1705+ patterns.add <ReuseF32Expansion<math::ErfOp>>(context);
1706+ }
1707+ if (predicate (" exp" )) {
1708+ patterns.add <ReuseF32Expansion<math::ExpOp>>(context);
1709+ }
1710+ if (predicate (" exp2" )) {
1711+ patterns.add <ReuseF32Expansion<math::Exp2Op>>(context);
1712+ }
1713+ if (predicate (" expm1" )) {
1714+ patterns.add <ReuseF32Expansion<math::ExpM1Op>>(context);
1715+ }
1716+ if (predicate (" log" )) {
1717+ patterns.add <ReuseF32Expansion<math::LogOp>>(context);
1718+ }
1719+ if (predicate (" log10" )) {
1720+ patterns.add <ReuseF32Expansion<math::Log10Op>>(context);
1721+ }
1722+ if (predicate (" log2" )) {
1723+ patterns.add <ReuseF32Expansion<math::Log2Op>>(context);
1724+ }
1725+ if (predicate (" log1p" )) {
1726+ patterns.add <ReuseF32Expansion<math::Log1pOp>>(context);
1727+ }
1728+ if (predicate (" powf" )) {
1729+ patterns.add <ReuseF32Expansion<math::PowFOp>>(context);
1730+ }
1731+ if (predicate (" rsqrt" )) {
1732+ patterns.add <ReuseF32Expansion<math::RsqrtOp>>(context);
1733+ }
1734+ if (predicate (" sin" )) {
1735+ patterns.add <ReuseF32Expansion<math::SinOp>>(context);
1736+ }
1737+ if (predicate (" sinh" )) {
1738+ patterns.add <ReuseF32Expansion<math::SinhOp>>(context);
1739+ }
1740+ if (predicate (" sqrt" )) {
1741+ patterns.add <ReuseF32Expansion<math::SqrtOp>>(context);
1742+ }
1743+ if (predicate (" tan" )) {
1744+ patterns.add <ReuseF32Expansion<math::TanOp>>(context);
1745+ }
1746+ if (predicate (" tanh" )) {
1747+ patterns.add <ReuseF32Expansion<math::TanhOp>>(context);
1748+ }
1749+ }
1750+
1751+ void mlir::populateMathPolynomialApproximationPatterns (
1752+ RewritePatternSet &patterns,
1753+ const std::function<bool (StringRef)> &predicate) {
1754+ MLIRContext *context = patterns.getContext ();
1755+ if (predicate (" acos" )) {
1756+ patterns.add <AcosPolynomialApproximation>(context);
1757+ }
1758+ if (predicate (" asin" )) {
1759+ patterns.add <AsinPolynomialApproximation>(context);
1760+ }
1761+ if (predicate (" atan" )) {
1762+ patterns.add <AtanApproximation>(context);
1763+ }
1764+ if (predicate (" atan2" )) {
1765+ patterns.add <Atan2Approximation>(context);
1766+ }
1767+ if (predicate (" cbrt" )) {
1768+ patterns.add <CbrtApproximation>(context);
1769+ }
1770+ if (predicate (" cos" )) {
1771+ patterns.add <SinAndCosApproximation<false , math::CosOp>>(context);
1772+ }
1773+ if (predicate (" erf" )) {
1774+ patterns.add <ErfPolynomialApproximation>(context);
1775+ }
1776+ if (predicate (" exp" )) {
1777+ patterns.add <ExpApproximation>(context);
1778+ }
1779+ if (predicate (" expm1" )) {
1780+ patterns.add <ExpM1Approximation>(context);
1781+ }
1782+ if (predicate (" log" )) {
1783+ patterns.add <LogApproximation>(context);
1784+ }
1785+ if (predicate (" log2" )) {
1786+ patterns.add <Log2Approximation>(context);
1787+ }
1788+ if (predicate (" log1p" )) {
1789+ patterns.add <Log1pApproximation>(context);
1790+ }
1791+ if (predicate (" rsqrt" )) {
1792+ patterns.add <RsqrtApproximation>(context);
1793+ }
1794+ if (predicate (" sin" )) {
1795+ patterns.add <SinAndCosApproximation<true , math::SinOp>>(context);
1796+ }
1797+ if (predicate (" tanh" )) {
1798+ patterns.add <TanhApproximation>(context);
1799+ }
1800+ }
1801+
16701802void mlir::populateMathPolynomialApproximationPatterns (
16711803 RewritePatternSet &patterns,
16721804 const MathPolynomialApproximationOptions &options) {
1673- // Patterns for leveraging existing f32 lowerings on other data types.
1674- patterns
1675- .add <ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
1676- ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
1677- ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
1678- ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
1679- ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
1680- ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1681- patterns.getContext ());
1682-
1683- patterns
1684- .add <AtanApproximation, Atan2Approximation, TanhApproximation,
1685- LogApproximation, Log2Approximation, Log1pApproximation,
1686- ErfPolynomialApproximation, AsinPolynomialApproximation,
1687- AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1688- CbrtApproximation, SinAndCosApproximation<true , math::SinOp>,
1689- SinAndCosApproximation<false , math::CosOp>>(patterns.getContext ());
1805+ mlir::populateMathF32ExpansionPatterns (patterns, [](StringRef name) {
1806+ return name == " atan" || name == " atan2" || name == " tanh" ||
1807+ name == " log" || name == " log2" || name == " log1p" ||
1808+ name == " erf" || name == " exp" || name == " expm1" ||
1809+ name == " cbrt" || name == " sin" || name == " cos" ;
1810+ });
1811+
1812+ populateMathPolynomialApproximationPatterns (patterns, [](StringRef name) {
1813+ return name == " atan" || name == " atan2" || name == " tanh" ||
1814+ name == " log" || name == " log2" || name == " log1p" ||
1815+ name == " erf" || name == " asin" || name == " acos" || name == " exp" ||
1816+ name == " expm1" || name == " cbrt" || name == " sin" || name == " cos" ;
1817+ });
1818+
16901819 if (options.enableAvx2 ) {
1691- patterns.add <RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
1692- patterns.getContext ());
1820+ auto predicateRsqrt = [](StringRef name) { return name == " rsqrt" ; };
1821+ mlir::populateMathF32ExpansionPatterns (patterns, predicateRsqrt);
1822+ mlir::populateMathPolynomialApproximationPatterns (patterns, predicateRsqrt);
16931823 }
16941824}
0 commit comments