From a09c4297da745fba484fb6f9a19169b114068223 Mon Sep 17 00:00:00 2001 From: Emilio Cota Date: Tue, 11 Mar 2025 10:27:25 -0400 Subject: [PATCH] [mlir][math] add benefit arg to populate math approximations/expansions This is a follow-up to #127291, which added the benefit arg to lowerings to intrinsics and libm. In this change we add the benefit arg to the math approximation and expansion lowerings, which allows users to establish a preferred order among all three math lowerings, namely approximations, intrinsics and libm. Note that we're only updating the new API added in #126103. The legacy one (`mlir::populateMathPolynomialApproximationPatterns`) is left unmodified to encourage users to move out of it. --- .../mlir/Dialect/Math/Transforms/Passes.h | 7 +- .../Transforms/PolynomialApproximation.cpp | 105 +++++++++--------- 2 files changed, 59 insertions(+), 53 deletions(-) diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h index 9adc1c6940a15..c0fe5d3be448a 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_ #define MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_ +#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" namespace mlir { @@ -52,12 +53,14 @@ void populatePolynomialApproximateErfcPattern(RewritePatternSet &patterns); // Adds patterns to convert to f32 around math functions for which `predicate` // returns true. void populateMathF32ExpansionPatterns( - RewritePatternSet &patterns, llvm::function_ref predicate); + RewritePatternSet &patterns, llvm::function_ref predicate, + PatternBenefit = 1); // Adds patterns to enable polynomial approximations for math functions for // which `predicate` returns true. void populateMathPolynomialApproximationPatterns( - RewritePatternSet &patterns, llvm::function_ref predicate); + RewritePatternSet &patterns, llvm::function_ref predicate, + PatternBenefit = 1); // Legacy. Calls both populateMathF32ExpansionPatterns and // populateMathPolynomialApproximationPatterns with predicates enabling a diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp index 167eebd786dba..a26e380232a91 100644 --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -1776,90 +1776,93 @@ void mlir::populatePolynomialApproximateErfcPattern( template static void populateMathF32ExpansionPattern(RewritePatternSet &patterns, - llvm::function_ref predicate) { + llvm::function_ref predicate, + PatternBenefit benefit) { if (predicate(OpType::getOperationName())) { - patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext(), benefit); } } void mlir::populateMathF32ExpansionPatterns( - RewritePatternSet &patterns, - llvm::function_ref predicate) { - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); + RewritePatternSet &patterns, llvm::function_ref predicate, + PatternBenefit benefit) { + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); } template static void populateMathPolynomialApproximationPattern( - RewritePatternSet &patterns, - llvm::function_ref predicate) { + RewritePatternSet &patterns, llvm::function_ref predicate, + PatternBenefit benefit) { if (predicate(OpType::getOperationName())) { - patterns.add(patterns.getContext()); + patterns.add(patterns.getContext(), benefit); } } void mlir::populateMathPolynomialApproximationPatterns( - RewritePatternSet &patterns, - llvm::function_ref predicate) { + RewritePatternSet &patterns, llvm::function_ref predicate, + PatternBenefit benefit) { populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern< - CosOp, SinAndCosApproximation>(patterns, predicate); + CosOp, SinAndCosApproximation>(patterns, predicate, + benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern< - SinOp, SinAndCosApproximation>(patterns, predicate); + SinOp, SinAndCosApproximation>(patterns, predicate, + benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); } void mlir::populateMathPolynomialApproximationPatterns(