@@ -1667,28 +1667,125 @@ void mlir::populatePolynomialApproximateErfPattern(
16671667 patterns.add <ErfPolynomialApproximation>(patterns.getContext ());
16681668}
16691669
1670+ template <typename OpType>
1671+ static void
1672+ populateMathF32ExpansionPattern (RewritePatternSet &patterns,
1673+ llvm::function_ref<bool (StringRef)> predicate) {
1674+ if (predicate (OpType::getOperationName ())) {
1675+ patterns.add <ReuseF32Expansion<OpType>>(patterns.getContext ());
1676+ }
1677+ }
1678+
1679+ void mlir::populateMathF32ExpansionPatterns (
1680+ RewritePatternSet &patterns,
1681+ llvm::function_ref<bool (StringRef)> predicate) {
1682+ populateMathF32ExpansionPattern<math::AcosOp>(patterns, predicate);
1683+ populateMathF32ExpansionPattern<math::AcoshOp>(patterns, predicate);
1684+ populateMathF32ExpansionPattern<math::AsinOp>(patterns, predicate);
1685+ populateMathF32ExpansionPattern<math::AsinhOp>(patterns, predicate);
1686+ populateMathF32ExpansionPattern<math::AtanOp>(patterns, predicate);
1687+ populateMathF32ExpansionPattern<math::Atan2Op>(patterns, predicate);
1688+ populateMathF32ExpansionPattern<math::AtanhOp>(patterns, predicate);
1689+ populateMathF32ExpansionPattern<math::CbrtOp>(patterns, predicate);
1690+ populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate);
1691+ populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate);
1692+ populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate);
1693+ populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate);
1694+ populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate);
1695+ populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate);
1696+ populateMathF32ExpansionPattern<math::LogOp>(patterns, predicate);
1697+ populateMathF32ExpansionPattern<math::Log10Op>(patterns, predicate);
1698+ populateMathF32ExpansionPattern<math::Log1pOp>(patterns, predicate);
1699+ populateMathF32ExpansionPattern<math::Log2Op>(patterns, predicate);
1700+ populateMathF32ExpansionPattern<math::PowFOp>(patterns, predicate);
1701+ populateMathF32ExpansionPattern<math::RsqrtOp>(patterns, predicate);
1702+ populateMathF32ExpansionPattern<math::SinOp>(patterns, predicate);
1703+ populateMathF32ExpansionPattern<math::SinhOp>(patterns, predicate);
1704+ populateMathF32ExpansionPattern<math::SqrtOp>(patterns, predicate);
1705+ populateMathF32ExpansionPattern<math::TanOp>(patterns, predicate);
1706+ populateMathF32ExpansionPattern<math::TanhOp>(patterns, predicate);
1707+ }
1708+
1709+ template <typename OpType, typename PatternType>
1710+ static void populateMathPolynomialApproximationPattern (
1711+ RewritePatternSet &patterns,
1712+ llvm::function_ref<bool (StringRef)> predicate) {
1713+ if (predicate (OpType::getOperationName ())) {
1714+ patterns.add <PatternType>(patterns.getContext ());
1715+ }
1716+ }
1717+
1718+ void mlir::populateMathPolynomialApproximationPatterns (
1719+ RewritePatternSet &patterns,
1720+ llvm::function_ref<bool (StringRef)> predicate) {
1721+ populateMathPolynomialApproximationPattern<AcosOp,
1722+ AcosPolynomialApproximation>(
1723+ patterns, predicate);
1724+ populateMathPolynomialApproximationPattern<AsinOp,
1725+ AsinPolynomialApproximation>(
1726+ patterns, predicate);
1727+ populateMathPolynomialApproximationPattern<AtanOp, AtanApproximation>(
1728+ patterns, predicate);
1729+ populateMathPolynomialApproximationPattern<Atan2Op, Atan2Approximation>(
1730+ patterns, predicate);
1731+ populateMathPolynomialApproximationPattern<CbrtOp, CbrtApproximation>(
1732+ patterns, predicate);
1733+ populateMathPolynomialApproximationPattern<
1734+ CosOp, SinAndCosApproximation<false , math::CosOp>>(patterns, predicate);
1735+ populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
1736+ patterns, predicate);
1737+ populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
1738+ patterns, predicate);
1739+ populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
1740+ patterns, predicate);
1741+ populateMathPolynomialApproximationPattern<LogOp, LogApproximation>(
1742+ patterns, predicate);
1743+ populateMathPolynomialApproximationPattern<Log2Op, Log2Approximation>(
1744+ patterns, predicate);
1745+ populateMathPolynomialApproximationPattern<Log1pOp, Log1pApproximation>(
1746+ patterns, predicate);
1747+ populateMathPolynomialApproximationPattern<RsqrtOp, RsqrtApproximation>(
1748+ patterns, predicate);
1749+ populateMathPolynomialApproximationPattern<
1750+ SinOp, SinAndCosApproximation<true , math::SinOp>>(patterns, predicate);
1751+ populateMathPolynomialApproximationPattern<TanhOp, TanhApproximation>(
1752+ patterns, predicate);
1753+ }
1754+
16701755void mlir::populateMathPolynomialApproximationPatterns (
16711756 RewritePatternSet &patterns,
16721757 const MathPolynomialApproximationOptions &options) {
1673- // Patterns for leveraging existing f32 lowerings on other data types.
1674- patterns
1675- .add <ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
1676- ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
1677- ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
1678- ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
1679- ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
1680- ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1681- patterns.getContext ());
1682-
1683- patterns
1684- .add <AtanApproximation, Atan2Approximation, TanhApproximation,
1685- LogApproximation, Log2Approximation, Log1pApproximation,
1686- ErfPolynomialApproximation, AsinPolynomialApproximation,
1687- AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1688- CbrtApproximation, SinAndCosApproximation<true , math::SinOp>,
1689- SinAndCosApproximation<false , math::CosOp>>(patterns.getContext ());
1758+ mlir::populateMathF32ExpansionPatterns (patterns, [](StringRef name) -> bool {
1759+ return llvm::is_contained (
1760+ {math::AtanOp::getOperationName (), math::Atan2Op::getOperationName (),
1761+ math::TanhOp::getOperationName (), math::LogOp::getOperationName (),
1762+ math::Log2Op::getOperationName (), math::Log1pOp::getOperationName (),
1763+ math::ErfOp::getOperationName (), math::ExpOp::getOperationName (),
1764+ math::ExpM1Op::getOperationName (), math::CbrtOp::getOperationName (),
1765+ math::SinOp::getOperationName (), math::CosOp::getOperationName ()},
1766+ name);
1767+ });
1768+
1769+ populateMathPolynomialApproximationPatterns (
1770+ patterns, [](StringRef name) -> bool {
1771+ return llvm::is_contained (
1772+ {math::AtanOp::getOperationName (),
1773+ math::Atan2Op::getOperationName (),
1774+ math::TanhOp::getOperationName (), math::LogOp::getOperationName (),
1775+ math::Log2Op::getOperationName (),
1776+ math::Log1pOp::getOperationName (), math::ErfOp::getOperationName (),
1777+ math::AsinOp::getOperationName (), math::AcosOp::getOperationName (),
1778+ math::ExpOp::getOperationName (), math::ExpM1Op::getOperationName (),
1779+ math::CbrtOp::getOperationName (), math::SinOp::getOperationName (),
1780+ math::CosOp::getOperationName ()},
1781+ name);
1782+ });
1783+
16901784 if (options.enableAvx2 ) {
1691- patterns.add <RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
1692- patterns.getContext ());
1785+ auto predicateRsqrt = [](StringRef name) {
1786+ return name == math::RsqrtOp::getOperationName ();
1787+ };
1788+ mlir::populateMathF32ExpansionPatterns (patterns, predicateRsqrt);
1789+ mlir::populateMathPolynomialApproximationPatterns (patterns, predicateRsqrt);
16931790 }
16941791}
0 commit comments