Skip to content

Commit e751f07

Browse files
committed
[AutoBump] Merge with fixes of 5c93eb5 (Feb 14)
2 parents 2b453ee + 5c93eb5 commit e751f07

File tree

4 files changed

+51
-47
lines changed

4 files changed

+51
-47
lines changed

mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H
1010
#define MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H
1111

12+
#include "mlir/IR/PatternMatch.h"
1213
#include <memory>
1314

1415
namespace mlir {
@@ -23,7 +24,8 @@ class Pass;
2324

2425
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter,
2526
RewritePatternSet &patterns,
26-
bool approximateLog1p = true);
27+
bool approximateLog1p = true,
28+
PatternBenefit benefit = 1);
2729

2830
void registerConvertMathToLLVMInterface(DialectRegistry &registry);
2931

mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ class OperationPass;
2020
/// Populate the given list with patterns that convert from Math to Libm calls.
2121
/// If log1pBenefit is present, use it instead of benefit for the Log1p op.
2222
void populateMathToLibmConversionPatterns(
23-
RewritePatternSet &patterns, const ConvertMathToLibmOptions &options);
23+
RewritePatternSet &patterns, const ConvertMathToLibmOptions &options,
24+
PatternBenefit benefit = 1);
2425

2526
/// Create a pass to convert Math operations to libm calls.
2627
std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();

mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,9 @@ struct ConvertMathToLLVMPass
304304

305305
void mlir::populateMathToLLVMConversionPatterns(
306306
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
307-
bool approximateLog1p) {
307+
bool approximateLog1p, PatternBenefit benefit) {
308308
if (approximateLog1p)
309-
patterns.add<Log1pOpLowering>(converter);
309+
patterns.add<Log1pOpLowering>(converter, benefit);
310310
// clang-format off
311311
patterns.add<
312312
AbsFOpLowering,
@@ -337,7 +337,7 @@ void mlir::populateMathToLLVMConversionPatterns(
337337
FTruncOpLowering,
338338
TanOpLowering,
339339
TanhOpLowering
340-
>(converter);
340+
>(converter, benefit);
341341
// clang-format on
342342
}
343343

mlir/lib/Conversion/MathToLibm/MathToLibm.cpp

Lines changed: 43 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ template <typename Op>
5050
struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
5151
public:
5252
using OpRewritePattern<Op>::OpRewritePattern;
53-
ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc,
54-
StringRef doubleFunc)
55-
: OpRewritePattern<Op>(context), floatFunc(floatFunc),
56-
doubleFunc(doubleFunc){};
53+
ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit,
54+
StringRef floatFunc, StringRef doubleFunc)
55+
: OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
56+
doubleFunc(doubleFunc) {};
5757

5858
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
5959

@@ -62,10 +62,11 @@ struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
6262
};
6363

6464
template <typename OpTy>
65-
void populatePatternsForOp(RewritePatternSet &patterns, MLIRContext *ctx,
66-
StringRef floatFunc, StringRef doubleFunc) {
67-
patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx);
68-
patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc);
65+
void populatePatternsForOp(RewritePatternSet &patterns, PatternBenefit benefit,
66+
MLIRContext *ctx, StringRef floatFunc,
67+
StringRef doubleFunc) {
68+
patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx, benefit);
69+
patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, benefit, floatFunc, doubleFunc);
6970
}
7071

7172
} // namespace
@@ -162,49 +163,49 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
162163
}
163164

164165
void mlir::populateMathToLibmConversionPatterns(
165-
RewritePatternSet &patterns, const ConvertMathToLibmOptions &options) {
166+
RewritePatternSet &patterns, const ConvertMathToLibmOptions &options, PatternBenefit benefit) {
166167
MLIRContext *ctx = patterns.getContext();
167168

168-
populatePatternsForOp<math::AbsFOp>(patterns, ctx, "fabsf", "fabs");
169-
populatePatternsForOp<math::AcosOp>(patterns, ctx, "acosf", "acos");
170-
populatePatternsForOp<math::AcoshOp>(patterns, ctx, "acoshf", "acosh");
171-
populatePatternsForOp<math::AsinOp>(patterns, ctx, "asinf", "asin");
172-
populatePatternsForOp<math::AsinhOp>(patterns, ctx, "asinhf", "asinh");
173-
populatePatternsForOp<math::Atan2Op>(patterns, ctx, "atan2f", "atan2");
174-
populatePatternsForOp<math::AtanOp>(patterns, ctx, "atanf", "atan");
175-
populatePatternsForOp<math::AtanhOp>(patterns, ctx, "atanhf", "atanh");
176-
populatePatternsForOp<math::CbrtOp>(patterns, ctx, "cbrtf", "cbrt");
177-
populatePatternsForOp<math::CeilOp>(patterns, ctx, "ceilf", "ceil");
178-
populatePatternsForOp<math::CosOp>(patterns, ctx, "cosf", "cos");
179-
populatePatternsForOp<math::CoshOp>(patterns, ctx, "coshf", "cosh");
180-
populatePatternsForOp<math::ErfOp>(patterns, ctx, "erff", "erf");
181-
populatePatternsForOp<math::ExpOp>(patterns, ctx, "expf", "exp");
182-
populatePatternsForOp<math::Exp2Op>(patterns, ctx, "exp2f", "exp2");
183-
populatePatternsForOp<math::ExpM1Op>(patterns, ctx, "expm1f", "expm1");
184-
populatePatternsForOp<math::FloorOp>(patterns, ctx, "floorf", "floor");
185-
populatePatternsForOp<math::FmaOp>(patterns, ctx, "fmaf", "fma");
186-
populatePatternsForOp<math::LogOp>(patterns, ctx, "logf", "log");
187-
populatePatternsForOp<math::Log2Op>(patterns, ctx, "log2f", "log2");
188-
populatePatternsForOp<math::Log10Op>(patterns, ctx, "log10f", "log10");
189-
populatePatternsForOp<math::Log1pOp>(patterns, ctx, "log1pf", "log1p");
190-
populatePatternsForOp<math::PowFOp>(patterns, ctx, "powf", "pow");
169+
populatePatternsForOp<math::AbsFOp>(patterns, benefit, ctx, "fabsf", "fabs");
170+
populatePatternsForOp<math::AcosOp>(patterns, benefit, ctx, "acosf", "acos");
171+
populatePatternsForOp<math::AcoshOp>(patterns, benefit, ctx, "acoshf", "acosh");
172+
populatePatternsForOp<math::AsinOp>(patterns, benefit, ctx, "asinf", "asin");
173+
populatePatternsForOp<math::AsinhOp>(patterns, benefit, ctx, "asinhf", "asinh");
174+
populatePatternsForOp<math::Atan2Op>(patterns, benefit, ctx, "atan2f", "atan2");
175+
populatePatternsForOp<math::AtanOp>(patterns, benefit, ctx, "atanf", "atan");
176+
populatePatternsForOp<math::AtanhOp>(patterns, benefit, ctx, "atanhf", "atanh");
177+
populatePatternsForOp<math::CbrtOp>(patterns, benefit, ctx, "cbrtf", "cbrt");
178+
populatePatternsForOp<math::CeilOp>(patterns, benefit, ctx, "ceilf", "ceil");
179+
populatePatternsForOp<math::CosOp>(patterns, benefit, ctx, "cosf", "cos");
180+
populatePatternsForOp<math::CoshOp>(patterns, benefit, ctx, "coshf", "cosh");
181+
populatePatternsForOp<math::ErfOp>(patterns, benefit, ctx, "erff", "erf");
182+
populatePatternsForOp<math::ExpOp>(patterns, benefit, ctx, "expf", "exp");
183+
populatePatternsForOp<math::Exp2Op>(patterns, benefit, ctx, "exp2f", "exp2");
184+
populatePatternsForOp<math::ExpM1Op>(patterns, benefit, ctx, "expm1f", "expm1");
185+
populatePatternsForOp<math::FloorOp>(patterns, benefit, ctx, "floorf", "floor");
186+
populatePatternsForOp<math::FmaOp>(patterns, benefit, ctx, "fmaf", "fma");
187+
populatePatternsForOp<math::LogOp>(patterns, benefit, ctx, "logf", "log");
188+
populatePatternsForOp<math::Log2Op>(patterns, benefit, ctx, "log2f", "log2");
189+
populatePatternsForOp<math::Log10Op>(patterns, benefit, ctx, "log10f", "log10");
190+
populatePatternsForOp<math::Log1pOp>(patterns, benefit, ctx, "log1pf", "log1p");
191+
populatePatternsForOp<math::PowFOp>(patterns, benefit, ctx, "powf", "pow");
191192
if (options.allowC23Features)
192-
populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, "roundevenf",
193+
populatePatternsForOp<math::RoundEvenOp>(patterns, benefit, ctx, "roundevenf",
193194
"roundeven");
194195
else if (options.roundingModeIsDefault)
195-
populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, "nearbyintf",
196+
populatePatternsForOp<math::RoundEvenOp>(patterns, benefit, ctx, "nearbyintf",
196197
"nearbyint");
197198
// Roundeven: using nearbyint (pre-C23) for roundeven requires the
198199
// rounding mode to be FE_TONEAREST (the default). Otherwise we need to
199200
// issue a call to set the rounding mode (which this pass currently can't do).
200-
populatePatternsForOp<math::RoundOp>(patterns, ctx, "roundf", "round");
201-
populatePatternsForOp<math::SinOp>(patterns, ctx, "sinf", "sin");
202-
populatePatternsForOp<math::SinhOp>(patterns, ctx, "sinhf", "sinh");
203-
populatePatternsForOp<math::SqrtOp>(patterns, ctx, "sqrtf", "sqrt");
204-
populatePatternsForOp<math::RsqrtOp>(patterns, ctx, "rsqrtf", "rsqrt");
205-
populatePatternsForOp<math::TanOp>(patterns, ctx, "tanf", "tan");
206-
populatePatternsForOp<math::TanhOp>(patterns, ctx, "tanhf", "tanh");
207-
populatePatternsForOp<math::TruncOp>(patterns, ctx, "truncf", "trunc");
201+
populatePatternsForOp<math::RoundOp>(patterns, benefit, ctx, "roundf", "round");
202+
populatePatternsForOp<math::SinOp>(patterns, benefit, ctx, "sinf", "sin");
203+
populatePatternsForOp<math::SinhOp>(patterns, benefit, ctx, "sinhf", "sinh");
204+
populatePatternsForOp<math::SqrtOp>(patterns, benefit, ctx, "sqrtf", "sqrt");
205+
populatePatternsForOp<math::RsqrtOp>(patterns, benefit, ctx, "rsqrtf", "rsqrt");
206+
populatePatternsForOp<math::TanOp>(patterns, benefit, ctx, "tanf", "tan");
207+
populatePatternsForOp<math::TanhOp>(patterns, benefit, ctx, "tanhf", "tanh");
208+
populatePatternsForOp<math::TruncOp>(patterns, benefit, ctx, "truncf", "trunc");
208209
}
209210

210211
namespace {

0 commit comments

Comments
 (0)