Skip to content

Commit c0055ec

Browse files
[mlir][EmitC] Add MathToEmitC pass for math function lowering to EmitC (#113799)
This commit introduces a new MathToEmitC conversion pass that lowers selected math operations from the Math dialect to the emitc.call_opaque operation in the EmitC dialect. **Supported Math Operations:** The following operations are converted: - math.floor -> emitc.call_opaque<"floor"> - math.round -> emitc.call_opaque<"round"> - math.exp -> emitc.call_opaque<"exp"> - math.cos -> emitc.call_opaque<"cos"> - math.sin -> emitc.call_opaque<"sin"> - math.acos -> emitc.call_opaque<"acos"> - math.asin -> emitc.call_opaque<"asin"> - math.atan2 -> emitc.call_opaque<"atan2"> - math.ceil -> emitc.call_opaque<"ceil"> - math.absf -> emitc.call_opaque<"fabs"> - math.powf -> emitc.call_opaque<"pow"> **Target Language Standards:** The pass supports targeting different language standards: - C99: Generates calls with suffixes (e.g., floorf, fabsf) for single-precision floats. - CPP11: Prepends std:: to functions (e.g., std::floor, std::fabs). **Design Decisions:** The pass uses emitc.call_opaque instead of emitc.call to better emulate C-style function overloading. emitc.call_opaque does not require a unique type signature, making it more suitable for operations like <math.h> functions that may be overloaded for different types. This design choice ensures compatibility with C/C++ conventions.
1 parent 7a77f14 commit c0055ec

File tree

11 files changed

+384
-0
lines changed

11 files changed

+384
-0
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//===- MathToEmitC.h - Math to EmitC Patterns -------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
10+
#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
11+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
12+
namespace mlir {
13+
class RewritePatternSet;
14+
namespace emitc {
15+
16+
/// Enum to specify the language target for EmitC code generation.
17+
enum class LanguageTarget { c99, cpp11 };
18+
19+
} // namespace emitc
20+
21+
void populateConvertMathToEmitCPatterns(RewritePatternSet &patterns,
22+
emitc::LanguageTarget languageTarget);
23+
} // namespace mlir
24+
25+
#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- MathToEmitCPass.h - Math to EmitC Pass -------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H
10+
#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H
11+
12+
#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
13+
#include <memory>
14+
namespace mlir {
15+
class Pass;
16+
17+
#define GEN_PASS_DECL_CONVERTMATHTOEMITC
18+
#include "mlir/Conversion/Passes.h.inc"
19+
} // namespace mlir
20+
21+
#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
4444
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
4545
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
46+
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
4647
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
4748
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
4849
#include "mlir/Conversion/MathToLibm/MathToLibm.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,28 @@ def ConvertMathToSPIRV : Pass<"convert-math-to-spirv"> {
790790
let dependentDialects = ["spirv::SPIRVDialect"];
791791
}
792792

793+
//===----------------------------------------------------------------------===//
794+
// MathToEmitC
795+
//===----------------------------------------------------------------------===//
796+
797+
def ConvertMathToEmitC : Pass<"convert-math-to-emitc"> {
798+
let summary = "Convert some Math operations to EmitC call_opaque operations";
799+
let description = [{
800+
This pass converts supported Math ops to `call_opaque` ops targeting libc/libm
801+
functions. Unlike convert-math-to-funcs pass, converting to `call_opaque` ops
802+
allows to overload the same function with different argument types.
803+
}];
804+
let dependentDialects = ["emitc::EmitCDialect"];
805+
let options = [
806+
Option<"languageTarget", "language-target", "::mlir::emitc::LanguageTarget",
807+
/*default=*/"::mlir::emitc::LanguageTarget::c99", "Select the language standard target for callees (c99 or cpp11).",
808+
[{::llvm::cl::values(
809+
clEnumValN(::mlir::emitc::LanguageTarget::c99, "c99", "c99"),
810+
clEnumValN(::mlir::emitc::LanguageTarget::cpp11, "cpp11", "cpp11")
811+
)}]>
812+
];
813+
}
814+
793815
//===----------------------------------------------------------------------===//
794816
// MathToFuncs
795817
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ add_subdirectory(IndexToLLVM)
3333
add_subdirectory(IndexToSPIRV)
3434
add_subdirectory(LinalgToStandard)
3535
add_subdirectory(LLVMCommon)
36+
add_subdirectory(MathToEmitC)
3637
add_subdirectory(MathToFuncs)
3738
add_subdirectory(MathToLibm)
3839
add_subdirectory(MathToLLVM)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
add_mlir_conversion_library(MLIRMathToEmitC
2+
MathToEmitC.cpp
3+
MathToEmitCPass.cpp
4+
5+
ADDITIONAL_HEADER_DIRS
6+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToEmitC
7+
8+
DEPENDS
9+
MLIRConversionPassIncGen
10+
11+
LINK_COMPONENTS
12+
Core
13+
14+
LINK_LIBS PUBLIC
15+
MLIREmitCDialect
16+
MLIRMathDialect
17+
MLIRPass
18+
MLIRTransformUtils
19+
)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
//===- MathToEmitC.cpp - Math to EmitC Patterns -----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
10+
11+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
12+
#include "mlir/Dialect/Math/IR/Math.h"
13+
#include "mlir/Transforms/DialectConversion.h"
14+
15+
using namespace mlir;
16+
17+
namespace {
18+
template <typename OpType>
19+
class LowerToEmitCCallOpaque : public OpRewritePattern<OpType> {
20+
std::string calleeStr;
21+
emitc::LanguageTarget languageTarget;
22+
23+
public:
24+
LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr,
25+
emitc::LanguageTarget languageTarget)
26+
: OpRewritePattern<OpType>(context), calleeStr(std::move(calleeStr)),
27+
languageTarget(languageTarget) {}
28+
29+
LogicalResult matchAndRewrite(OpType op,
30+
PatternRewriter &rewriter) const override;
31+
};
32+
33+
template <typename OpType>
34+
LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
35+
OpType op, PatternRewriter &rewriter) const {
36+
if (!llvm::all_of(op->getOperandTypes(),
37+
llvm::IsaPred<Float32Type, Float64Type>) ||
38+
!llvm::all_of(op->getResultTypes(),
39+
llvm::IsaPred<Float32Type, Float64Type>))
40+
return rewriter.notifyMatchFailure(
41+
op.getLoc(),
42+
"expected all operands and results to be of type f32 or f64");
43+
std::string modifiedCalleeStr = calleeStr;
44+
if (languageTarget == emitc::LanguageTarget::cpp11) {
45+
modifiedCalleeStr = "std::" + calleeStr;
46+
} else if (languageTarget == emitc::LanguageTarget::c99) {
47+
auto operandType = op->getOperandTypes()[0];
48+
if (operandType.isF32())
49+
modifiedCalleeStr = calleeStr + "f";
50+
}
51+
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
52+
op, op.getType(), modifiedCalleeStr, op->getOperands());
53+
return success();
54+
}
55+
56+
} // namespace
57+
58+
// Populates patterns to replace `math` operations with `emitc.call_opaque`,
59+
// using function names consistent with those in <math.h>.
60+
void mlir::populateConvertMathToEmitCPatterns(
61+
RewritePatternSet &patterns, emitc::LanguageTarget languageTarget) {
62+
auto *context = patterns.getContext();
63+
patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor",
64+
languageTarget);
65+
patterns.insert<LowerToEmitCCallOpaque<math::RoundOp>>(context, "round",
66+
languageTarget);
67+
patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "exp",
68+
languageTarget);
69+
patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cos",
70+
languageTarget);
71+
patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sin",
72+
languageTarget);
73+
patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acos",
74+
languageTarget);
75+
patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asin",
76+
languageTarget);
77+
patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2",
78+
languageTarget);
79+
patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil",
80+
languageTarget);
81+
patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs",
82+
languageTarget);
83+
patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "pow",
84+
languageTarget);
85+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//===- MathToEmitCPass.cpp - Math to EmitC Pass -----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to convert the Math dialect to the EmitC dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
14+
#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
15+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
16+
#include "mlir/Dialect/Math/IR/Math.h"
17+
#include "mlir/Pass/Pass.h"
18+
#include "mlir/Transforms/DialectConversion.h"
19+
20+
namespace mlir {
21+
#define GEN_PASS_DEF_CONVERTMATHTOEMITC
22+
#include "mlir/Conversion/Passes.h.inc"
23+
} // namespace mlir
24+
25+
using namespace mlir;
26+
namespace {
27+
28+
// Replaces Math operations with `emitc.call_opaque` operations.
29+
struct ConvertMathToEmitC
30+
: public impl::ConvertMathToEmitCBase<ConvertMathToEmitC> {
31+
using ConvertMathToEmitCBase::ConvertMathToEmitCBase;
32+
33+
public:
34+
void runOnOperation() final;
35+
};
36+
37+
} // namespace
38+
39+
void ConvertMathToEmitC::runOnOperation() {
40+
ConversionTarget target(getContext());
41+
target.addLegalOp<emitc::CallOpaqueOp>();
42+
43+
target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundOp, math::CosOp,
44+
math::SinOp, math::Atan2Op, math::CeilOp, math::AcosOp,
45+
math::AsinOp, math::AbsFOp, math::PowFOp>();
46+
47+
RewritePatternSet patterns(&getContext());
48+
populateConvertMathToEmitCPatterns(patterns, languageTarget);
49+
50+
if (failed(
51+
applyPartialConversion(getOperation(), target, std::move(patterns))))
52+
signalPassFailure();
53+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt -split-input-file -convert-math-to-emitc -verify-diagnostics %s
2+
3+
func.func @unsupported_tensor_type(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
4+
// expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}}
5+
%0 = math.absf %arg0 : tensor<4xf32>
6+
return %0 : tensor<4xf32>
7+
}
8+
9+
// -----
10+
11+
func.func @unsupported_f16_type(%arg0 : f16) -> f16 {
12+
// expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}}
13+
%0 = math.absf %arg0 : f16
14+
return %0 : f16
15+
}
16+
17+
// -----
18+
19+
func.func @unsupported_f128_type(%arg0 : f128) -> f128 {
20+
// expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}}
21+
%0 = math.absf %arg0 : f128
22+
return %0 : f128
23+
}
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
// RUN: mlir-opt -convert-math-to-emitc=language-target=c99 %s | FileCheck %s --check-prefix=c99
2+
// RUN: mlir-opt -convert-math-to-emitc=language-target=cpp11 %s | FileCheck %s --check-prefix=cpp11
3+
4+
func.func @absf(%arg0: f32, %arg1: f64) {
5+
// c99: emitc.call_opaque "fabsf"
6+
// c99-NEXT: emitc.call_opaque "fabs"
7+
// cpp11: emitc.call_opaque "std::fabs"
8+
// cpp11-NEXT: emitc.call_opaque "std::fabs"
9+
%0 = math.absf %arg0 : f32
10+
%1 = math.absf %arg1 : f64
11+
return
12+
}
13+
14+
func.func @floor(%arg0: f32, %arg1: f64) {
15+
// c99: emitc.call_opaque "floorf"
16+
// c99-NEXT: emitc.call_opaque "floor"
17+
// cpp11: emitc.call_opaque "std::floor"
18+
// cpp11-NEXT: emitc.call_opaque "std::floor"
19+
%0 = math.floor %arg0 : f32
20+
%1 = math.floor %arg1 : f64
21+
return
22+
}
23+
24+
func.func @sin(%arg0: f32, %arg1: f64) {
25+
// c99: emitc.call_opaque "sinf"
26+
// c99-NEXT: emitc.call_opaque "sin"
27+
// cpp11: emitc.call_opaque "std::sin"
28+
// cpp11-NEXT: emitc.call_opaque "std::sin"
29+
%0 = math.sin %arg0 : f32
30+
%1 = math.sin %arg1 : f64
31+
return
32+
}
33+
34+
func.func @cos(%arg0: f32, %arg1: f64) {
35+
// c99: emitc.call_opaque "cosf"
36+
// c99-NEXT: emitc.call_opaque "cos"
37+
// cpp11: emitc.call_opaque "std::cos"
38+
// cpp11-NEXT: emitc.call_opaque "std::cos"
39+
%0 = math.cos %arg0 : f32
40+
%1 = math.cos %arg1 : f64
41+
return
42+
}
43+
44+
func.func @asin(%arg0: f32, %arg1: f64) {
45+
// c99: emitc.call_opaque "asinf"
46+
// c99-NEXT: emitc.call_opaque "asin"
47+
// cpp11: emitc.call_opaque "std::asin"
48+
// cpp11-NEXT: emitc.call_opaque "std::asin"
49+
%0 = math.asin %arg0 : f32
50+
%1 = math.asin %arg1 : f64
51+
return
52+
}
53+
54+
func.func @acos(%arg0: f32, %arg1: f64) {
55+
// c99: emitc.call_opaque "acosf"
56+
// c99-NEXT: emitc.call_opaque "acos"
57+
// cpp11: emitc.call_opaque "std::acos"
58+
// cpp11-NEXT: emitc.call_opaque "std::acos"
59+
%0 = math.acos %arg0 : f32
60+
%1 = math.acos %arg1 : f64
61+
return
62+
}
63+
64+
func.func @atan2(%arg0: f32, %arg1: f32, %arg2: f64, %arg3: f64) {
65+
// c99: emitc.call_opaque "atan2f"
66+
// c99-NEXT: emitc.call_opaque "atan2"
67+
// cpp11: emitc.call_opaque "std::atan2"
68+
// cpp11-NEXT: emitc.call_opaque "std::atan2"
69+
%0 = math.atan2 %arg0, %arg1 : f32
70+
%1 = math.atan2 %arg2, %arg3 : f64
71+
return
72+
}
73+
74+
func.func @ceil(%arg0: f32, %arg1: f64) {
75+
// c99: emitc.call_opaque "ceilf"
76+
// c99-NEXT: emitc.call_opaque "ceil"
77+
// cpp11: emitc.call_opaque "std::ceil"
78+
// cpp11-NEXT: emitc.call_opaque "std::ceil"
79+
%0 = math.ceil %arg0 : f32
80+
%1 = math.ceil %arg1 : f64
81+
return
82+
}
83+
84+
func.func @exp(%arg0: f32, %arg1: f64) {
85+
// c99: emitc.call_opaque "expf"
86+
// c99-NEXT: emitc.call_opaque "exp"
87+
// cpp11: emitc.call_opaque "std::exp"
88+
// cpp11-NEXT: emitc.call_opaque "std::exp"
89+
%0 = math.exp %arg0 : f32
90+
%1 = math.exp %arg1 : f64
91+
return
92+
}
93+
94+
func.func @powf(%arg0: f32, %arg1: f32, %arg2: f64, %arg3: f64) {
95+
// c99: emitc.call_opaque "powf"
96+
// c99-NEXT: emitc.call_opaque "pow"
97+
// cpp11: emitc.call_opaque "std::pow"
98+
// cpp11-NEXT: emitc.call_opaque "std::pow"
99+
%0 = math.powf %arg0, %arg1 : f32
100+
%1 = math.powf %arg2, %arg3 : f64
101+
return
102+
}
103+
104+
func.func @round(%arg0: f32, %arg1: f64) {
105+
// c99: emitc.call_opaque "roundf"
106+
// c99-NEXT: emitc.call_opaque "round"
107+
// cpp11: emitc.call_opaque "std::round"
108+
// cpp11-NEXT: emitc.call_opaque "std::round"
109+
%0 = math.round %arg0 : f32
110+
%1 = math.round %arg1 : f64
111+
return
112+
}

0 commit comments

Comments
 (0)