diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h index 9adc1c6940a15..c0fe5d3be448a 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_ #define MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_ +#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" namespace mlir { @@ -52,12 +53,14 @@ void populatePolynomialApproximateErfcPattern(RewritePatternSet &patterns); // Adds patterns to convert to f32 around math functions for which `predicate` // returns true. void populateMathF32ExpansionPatterns( - RewritePatternSet &patterns, llvm::function_ref predicate); + RewritePatternSet &patterns, llvm::function_ref predicate, + PatternBenefit = 1); // Adds patterns to enable polynomial approximations for math functions for // which `predicate` returns true. void populateMathPolynomialApproximationPatterns( - RewritePatternSet &patterns, llvm::function_ref predicate); + RewritePatternSet &patterns, llvm::function_ref predicate, + PatternBenefit = 1); // Legacy. Calls both populateMathF32ExpansionPatterns and // populateMathPolynomialApproximationPatterns with predicates enabling a diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp index 167eebd786dba..a26e380232a91 100644 --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -1776,90 +1776,93 @@ void mlir::populatePolynomialApproximateErfcPattern( template static void populateMathF32ExpansionPattern(RewritePatternSet &patterns, - llvm::function_ref predicate) { + llvm::function_ref predicate, + PatternBenefit benefit) { if (predicate(OpType::getOperationName())) { - patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext(), benefit); } } void mlir::populateMathF32ExpansionPatterns( - RewritePatternSet &patterns, - llvm::function_ref predicate) { - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); - populateMathF32ExpansionPattern(patterns, predicate); + RewritePatternSet &patterns, llvm::function_ref predicate, + PatternBenefit benefit) { + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); + populateMathF32ExpansionPattern(patterns, predicate, benefit); } template static void populateMathPolynomialApproximationPattern( - RewritePatternSet &patterns, - llvm::function_ref predicate) { + RewritePatternSet &patterns, llvm::function_ref predicate, + PatternBenefit benefit) { if (predicate(OpType::getOperationName())) { - patterns.add(patterns.getContext()); + patterns.add(patterns.getContext(), benefit); } } void mlir::populateMathPolynomialApproximationPatterns( - RewritePatternSet &patterns, - llvm::function_ref predicate) { + RewritePatternSet &patterns, llvm::function_ref predicate, + PatternBenefit benefit) { populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern< - CosOp, SinAndCosApproximation>(patterns, predicate); + CosOp, SinAndCosApproximation>(patterns, predicate, + benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); populateMathPolynomialApproximationPattern< - SinOp, SinAndCosApproximation>(patterns, predicate); + SinOp, SinAndCosApproximation>(patterns, predicate, + benefit); populateMathPolynomialApproximationPattern( - patterns, predicate); + patterns, predicate, benefit); } void mlir::populateMathPolynomialApproximationPatterns(