@@ -50,10 +50,10 @@ template <typename Op>
5050struct ScalarOpToLibmCall : public OpRewritePattern <Op> {
5151public:
5252 using OpRewritePattern<Op>::OpRewritePattern;
53- ScalarOpToLibmCall (MLIRContext *context, StringRef floatFunc ,
54- StringRef doubleFunc)
55- : OpRewritePattern<Op>(context), floatFunc(floatFunc),
56- doubleFunc (doubleFunc){};
53+ ScalarOpToLibmCall (MLIRContext *context, PatternBenefit benefit ,
54+ StringRef floatFunc, StringRef doubleFunc)
55+ : OpRewritePattern<Op>(context, benefit ), floatFunc(floatFunc),
56+ doubleFunc (doubleFunc) {};
5757
5858 LogicalResult matchAndRewrite (Op op, PatternRewriter &rewriter) const final ;
5959
@@ -62,10 +62,11 @@ struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
6262};
6363
6464template <typename OpTy>
65- void populatePatternsForOp (RewritePatternSet &patterns, MLIRContext *ctx,
66- StringRef floatFunc, StringRef doubleFunc) {
67- patterns.add <VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx);
68- patterns.add <ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc);
65+ void populatePatternsForOp (RewritePatternSet &patterns, PatternBenefit benefit,
66+ MLIRContext *ctx, StringRef floatFunc,
67+ StringRef doubleFunc) {
68+ patterns.add <VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx, benefit);
69+ patterns.add <ScalarOpToLibmCall<OpTy>>(ctx, benefit, floatFunc, doubleFunc);
6970}
7071
7172} // namespace
@@ -162,49 +163,49 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
162163}
163164
164165void mlir::populateMathToLibmConversionPatterns (
165- RewritePatternSet &patterns, const ConvertMathToLibmOptions &options) {
166+ RewritePatternSet &patterns, const ConvertMathToLibmOptions &options, PatternBenefit benefit ) {
166167 MLIRContext *ctx = patterns.getContext ();
167168
168- populatePatternsForOp<math::AbsFOp>(patterns, ctx, " fabsf" , " fabs" );
169- populatePatternsForOp<math::AcosOp>(patterns, ctx, " acosf" , " acos" );
170- populatePatternsForOp<math::AcoshOp>(patterns, ctx, " acoshf" , " acosh" );
171- populatePatternsForOp<math::AsinOp>(patterns, ctx, " asinf" , " asin" );
172- populatePatternsForOp<math::AsinhOp>(patterns, ctx, " asinhf" , " asinh" );
173- populatePatternsForOp<math::Atan2Op>(patterns, ctx, " atan2f" , " atan2" );
174- populatePatternsForOp<math::AtanOp>(patterns, ctx, " atanf" , " atan" );
175- populatePatternsForOp<math::AtanhOp>(patterns, ctx, " atanhf" , " atanh" );
176- populatePatternsForOp<math::CbrtOp>(patterns, ctx, " cbrtf" , " cbrt" );
177- populatePatternsForOp<math::CeilOp>(patterns, ctx, " ceilf" , " ceil" );
178- populatePatternsForOp<math::CosOp>(patterns, ctx, " cosf" , " cos" );
179- populatePatternsForOp<math::CoshOp>(patterns, ctx, " coshf" , " cosh" );
180- populatePatternsForOp<math::ErfOp>(patterns, ctx, " erff" , " erf" );
181- populatePatternsForOp<math::ExpOp>(patterns, ctx, " expf" , " exp" );
182- populatePatternsForOp<math::Exp2Op>(patterns, ctx, " exp2f" , " exp2" );
183- populatePatternsForOp<math::ExpM1Op>(patterns, ctx, " expm1f" , " expm1" );
184- populatePatternsForOp<math::FloorOp>(patterns, ctx, " floorf" , " floor" );
185- populatePatternsForOp<math::FmaOp>(patterns, ctx, " fmaf" , " fma" );
186- populatePatternsForOp<math::LogOp>(patterns, ctx, " logf" , " log" );
187- populatePatternsForOp<math::Log2Op>(patterns, ctx, " log2f" , " log2" );
188- populatePatternsForOp<math::Log10Op>(patterns, ctx, " log10f" , " log10" );
189- populatePatternsForOp<math::Log1pOp>(patterns, ctx, " log1pf" , " log1p" );
190- populatePatternsForOp<math::PowFOp>(patterns, ctx, " powf" , " pow" );
169+ populatePatternsForOp<math::AbsFOp>(patterns, benefit, ctx, " fabsf" , " fabs" );
170+ populatePatternsForOp<math::AcosOp>(patterns, benefit, ctx, " acosf" , " acos" );
171+ populatePatternsForOp<math::AcoshOp>(patterns, benefit, ctx, " acoshf" , " acosh" );
172+ populatePatternsForOp<math::AsinOp>(patterns, benefit, ctx, " asinf" , " asin" );
173+ populatePatternsForOp<math::AsinhOp>(patterns, benefit, ctx, " asinhf" , " asinh" );
174+ populatePatternsForOp<math::Atan2Op>(patterns, benefit, ctx, " atan2f" , " atan2" );
175+ populatePatternsForOp<math::AtanOp>(patterns, benefit, ctx, " atanf" , " atan" );
176+ populatePatternsForOp<math::AtanhOp>(patterns, benefit, ctx, " atanhf" , " atanh" );
177+ populatePatternsForOp<math::CbrtOp>(patterns, benefit, ctx, " cbrtf" , " cbrt" );
178+ populatePatternsForOp<math::CeilOp>(patterns, benefit, ctx, " ceilf" , " ceil" );
179+ populatePatternsForOp<math::CosOp>(patterns, benefit, ctx, " cosf" , " cos" );
180+ populatePatternsForOp<math::CoshOp>(patterns, benefit, ctx, " coshf" , " cosh" );
181+ populatePatternsForOp<math::ErfOp>(patterns, benefit, ctx, " erff" , " erf" );
182+ populatePatternsForOp<math::ExpOp>(patterns, benefit, ctx, " expf" , " exp" );
183+ populatePatternsForOp<math::Exp2Op>(patterns, benefit, ctx, " exp2f" , " exp2" );
184+ populatePatternsForOp<math::ExpM1Op>(patterns, benefit, ctx, " expm1f" , " expm1" );
185+ populatePatternsForOp<math::FloorOp>(patterns, benefit, ctx, " floorf" , " floor" );
186+ populatePatternsForOp<math::FmaOp>(patterns, benefit, ctx, " fmaf" , " fma" );
187+ populatePatternsForOp<math::LogOp>(patterns, benefit, ctx, " logf" , " log" );
188+ populatePatternsForOp<math::Log2Op>(patterns, benefit, ctx, " log2f" , " log2" );
189+ populatePatternsForOp<math::Log10Op>(patterns, benefit, ctx, " log10f" , " log10" );
190+ populatePatternsForOp<math::Log1pOp>(patterns, benefit, ctx, " log1pf" , " log1p" );
191+ populatePatternsForOp<math::PowFOp>(patterns, benefit, ctx, " powf" , " pow" );
191192 if (options.allowC23Features )
192- populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, " roundevenf" ,
193+ populatePatternsForOp<math::RoundEvenOp>(patterns, benefit, ctx, " roundevenf" ,
193194 " roundeven" );
194195 else if (options.roundingModeIsDefault )
195- populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, " nearbyintf" ,
196+ populatePatternsForOp<math::RoundEvenOp>(patterns, benefit, ctx, " nearbyintf" ,
196197 " nearbyint" );
197198 // Roundeven: using nearbyint (pre-C23) for roundeven requires the
198199 // rounding mode to be FE_TONEAREST (the default). Otherwise we need to
199200 // issue a call to set the rounding mode (which this pass currently can't do).
200- populatePatternsForOp<math::RoundOp>(patterns, ctx, " roundf" , " round" );
201- populatePatternsForOp<math::SinOp>(patterns, ctx, " sinf" , " sin" );
202- populatePatternsForOp<math::SinhOp>(patterns, ctx, " sinhf" , " sinh" );
203- populatePatternsForOp<math::SqrtOp>(patterns, ctx, " sqrtf" , " sqrt" );
204- populatePatternsForOp<math::RsqrtOp>(patterns, ctx, " rsqrtf" , " rsqrt" );
205- populatePatternsForOp<math::TanOp>(patterns, ctx, " tanf" , " tan" );
206- populatePatternsForOp<math::TanhOp>(patterns, ctx, " tanhf" , " tanh" );
207- populatePatternsForOp<math::TruncOp>(patterns, ctx, " truncf" , " trunc" );
201+ populatePatternsForOp<math::RoundOp>(patterns, benefit, ctx, " roundf" , " round" );
202+ populatePatternsForOp<math::SinOp>(patterns, benefit, ctx, " sinf" , " sin" );
203+ populatePatternsForOp<math::SinhOp>(patterns, benefit, ctx, " sinhf" , " sinh" );
204+ populatePatternsForOp<math::SqrtOp>(patterns, benefit, ctx, " sqrtf" , " sqrt" );
205+ populatePatternsForOp<math::RsqrtOp>(patterns, benefit, ctx, " rsqrtf" , " rsqrt" );
206+ populatePatternsForOp<math::TanOp>(patterns, benefit, ctx, " tanf" , " tan" );
207+ populatePatternsForOp<math::TanhOp>(patterns, benefit, ctx, " tanhf" , " tanh" );
208+ populatePatternsForOp<math::TruncOp>(patterns, benefit, ctx, " truncf" , " trunc" );
208209}
209210
210211namespace {
0 commit comments