Skip to content

Commit d5bd00c

Browse files
committed
[mlir][EmitC] Add MathToEmitC pass for math function lowering to EmitC
This commit introduces a new `MathToEmitC` conversion pass that lowers selected math operations to the `emitc.call_opaque` operation in the EmitC dialect. The supported math operations include: - math.floor -> emitc.call_opaque<"floor"> - math.exp -> emitc.call_opaque<"exp"> - math.cos -> emitc.call_opaque<"cos"> - math.sin -> emitc.call_opaque<"sin"> - math.ipowi -> emitc.call_opaque<"pow"> We chose to use `emitc.call_opaque` instead of `emitc.call` to better align with C-style function overloading. Unlike `emitc.call`, which requires unique type signatures, `emitc.call_opaque` allows us to call functions without specifying a unique type-based signature. This flexibility is essential for mimicking function overloading behavior as seen in `<math.h>`. Additionally, the pass inserts an `emitc.include` operation to generate `#include <math.h>` at the top of the module to ensure the availability of the necessary math functions in the generated code. This pass enables the use of EmitC as an intermediate layer to generate C/C++ code with opaque calls to standard math functions.
1 parent a2ba438 commit d5bd00c

File tree

8 files changed

+326
-0
lines changed

8 files changed

+326
-0
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//===- MathToEmitC.h - Math to EmitC Pass -----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
10+
#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
11+
12+
#include "mlir/IR/BuiltinOps.h"
13+
#include "mlir/Pass/Pass.h"
14+
#include <memory>
15+
16+
namespace mlir {
17+
18+
#define GEN_PASS_DECL_CONVERTMATHTOEMITC
19+
#include "mlir/Conversion/Passes.h.inc"
20+
21+
std::unique_ptr<OperationPass<mlir::ModuleOp>> createConvertMathToEmitCPass();
22+
23+
} // namespace mlir
24+
25+
#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
4444
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
4545
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
46+
#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
4647
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
4748
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
4849
#include "mlir/Conversion/MathToLibm/MathToLibm.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,25 @@ def ConvertMathToSPIRV : Pass<"convert-math-to-spirv"> {
780780
let dependentDialects = ["spirv::SPIRVDialect"];
781781
}
782782

783+
//===----------------------------------------------------------------------===//
784+
// MathToEmitC
785+
//===----------------------------------------------------------------------===//
786+
787+
def ConvertMathToEmitC : Pass<"convert-math-to-emitc", "ModuleOp"> {
788+
let summary = "Convert some Math operations to EmitC Call_opaque";
789+
let description = [{
790+
This pass converts supported Math ops to call_opaque calls to compiler generated
791+
functions implementing these operations in software.
792+
Unlike convert-math-to-funcs pass, this pass uses call_opaque,
793+
therefore enables us to overload the same funtion with different argument types
794+
}];
795+
796+
let constructor = "mlir::createConvertMathToEmitCPass()";
797+
let dependentDialects = ["emitc::EmitCDialect",
798+
"math::MathDialect"
799+
];
800+
}
801+
783802
//===----------------------------------------------------------------------===//
784803
// MathToFuncs
785804
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ add_subdirectory(IndexToLLVM)
3333
add_subdirectory(IndexToSPIRV)
3434
add_subdirectory(LinalgToStandard)
3535
add_subdirectory(LLVMCommon)
36+
add_subdirectory(MathToEmitC)
3637
add_subdirectory(MathToFuncs)
3738
add_subdirectory(MathToLibm)
3839
add_subdirectory(MathToLLVM)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
add_mlir_conversion_library(MLIRMathToEmitC
2+
MathToEmitC.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToEmitC
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIRLLVMCommonConversion
15+
MLIREmitCDialect
16+
MLIRMathDialect
17+
MLIRPass
18+
MLIRTransforms
19+
)
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
2+
//===- MathToEmitC.cpp - Math to EmitC Pass Implementation ----------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
11+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
12+
#include "mlir/Dialect/Math/IR/Math.h"
13+
#include "mlir/Pass/Pass.h"
14+
#include "mlir/Transforms/DialectConversion.h"
15+
16+
namespace mlir {
17+
#define GEN_PASS_DEF_CONVERTMATHTOEMITC
18+
#include "mlir/Conversion/Passes.h.inc"
19+
} // namespace mlir
20+
21+
using namespace mlir;
22+
namespace {
23+
24+
// Replaces Math operations with `emitc.call_opaque` operations.
25+
struct ConvertMathToEmitCPass
26+
: public impl::ConvertMathToEmitCBase<ConvertMathToEmitCPass> {
27+
public:
28+
void runOnOperation() final;
29+
};
30+
31+
} // end anonymous namespace
32+
33+
template <typename OpType>
34+
class LowerToEmitCCallOpaque : public mlir::OpRewritePattern<OpType> {
35+
std::string calleeStr;
36+
37+
public:
38+
LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr)
39+
: OpRewritePattern<OpType>(context), calleeStr(calleeStr) {}
40+
41+
LogicalResult matchAndRewrite(OpType op,
42+
PatternRewriter &rewriter) const override;
43+
};
44+
45+
// Populates patterns to replace `math` operations with `emitc.call_opaque`,
46+
// using function names consistent with those in <math.h>.
47+
static void populateConvertMathToEmitCPatterns(RewritePatternSet &patterns) {
48+
auto *context = patterns.getContext();
49+
patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor");
50+
patterns.insert<LowerToEmitCCallOpaque<math::RoundEvenOp>>(context, "rint");
51+
patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "exp");
52+
patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cos");
53+
patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sin");
54+
patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acos");
55+
patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asin");
56+
patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2");
57+
patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil");
58+
patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs");
59+
patterns.insert<LowerToEmitCCallOpaque<math::FPowIOp>>(context, "powf");
60+
patterns.insert<LowerToEmitCCallOpaque<math::IPowIOp>>(context, "pow");
61+
}
62+
63+
template <typename OpType>
64+
LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
65+
OpType op, PatternRewriter &rewriter) const {
66+
mlir::StringAttr callee = rewriter.getStringAttr(calleeStr);
67+
auto actualOp = mlir::cast<OpType>(op);
68+
rewriter.replaceOpWithNewOp<mlir::emitc::CallOpaqueOp>(
69+
actualOp, actualOp.getType(), callee, actualOp->getOperands());
70+
return mlir::success();
71+
}
72+
73+
void ConvertMathToEmitCPass::runOnOperation() {
74+
auto moduleOp = getOperation();
75+
// Insert #include <math.h> at the beginning of the module
76+
OpBuilder builder(moduleOp.getBodyRegion());
77+
builder.setInsertionPointToStart(&moduleOp.getBodyRegion().front());
78+
builder.create<emitc::IncludeOp>(moduleOp.getLoc(),
79+
builder.getStringAttr("math.h"));
80+
81+
ConversionTarget target(getContext());
82+
target.addLegalOp<emitc::CallOpaqueOp>();
83+
84+
target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundEvenOp,
85+
math::CosOp, math::SinOp, math::Atan2Op, math::CeilOp,
86+
math::AcosOp, math::AsinOp, math::AbsFOp, math::PowFOp,
87+
math::FPowIOp, math::IPowIOp>();
88+
89+
RewritePatternSet patterns(&getContext());
90+
populateConvertMathToEmitCPatterns(patterns);
91+
92+
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
93+
signalPassFailure();
94+
}
95+
96+
std::unique_ptr<OperationPass<mlir::ModuleOp>>
97+
mlir::createConvertMathToEmitCPass() {
98+
return std::make_unique<ConvertMathToEmitCPass>();
99+
}
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
// RUN: mlir-opt --split-input-file -convert-math-to-emitc %s | FileCheck %s
2+
3+
// CHECK-LABEL: emitc.include "math.h"
4+
5+
// CHECK-LABEL: func.func @absf_to_call_opaque(
6+
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
7+
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "fabs"(%[[VAL_0]]) : (f32) -> f32
8+
// CHECK: return
9+
// CHECK: }
10+
func.func @absf_to_call_opaque(%arg0: f32) {
11+
%1 = math.absf %arg0 : f32
12+
return
13+
}
14+
15+
// -----
16+
17+
// CHECK-LABEL: func.func @floor_to_call_opaque(
18+
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
19+
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "floor"(%[[VAL_0]]) : (f32) -> f32
20+
// CHECK: return
21+
// CHECK: }
22+
func.func @floor_to_call_opaque(%arg0: f32) {
23+
%1 = math.floor %arg0 : f32
24+
return
25+
}
26+
27+
// -----
28+
29+
// CHECK-LABEL: func.func @sin_to_call_opaque(
30+
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
31+
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "sin"(%[[VAL_0]]) : (f32) -> f32
32+
// CHECK: return
33+
// CHECK: }
34+
func.func @sin_to_call_opaque(%arg0: f32) {
35+
%1 = math.sin %arg0 : f32
36+
return
37+
}
38+
39+
// -----
40+
41+
// CHECK-LABEL: func.func @cos_to_call_opaque(
42+
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
43+
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "cos"(%[[VAL_0]]) : (f32) -> f32
44+
// CHECK: return
45+
// CHECK: }
46+
func.func @cos_to_call_opaque(%arg0: f32) {
47+
%1 = math.cos %arg0 : f32
48+
return
49+
}
50+
51+
52+
// -----
53+
54+
// CHECK-LABEL: func.func @asin_to_call_opaque(
55+
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
56+
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "asin"(%[[VAL_0]]) : (f32) -> f32
57+
// CHECK: return
58+
// CHECK: }
59+
func.func @asin_to_call_opaque(%arg0: f32) {
60+
%1 = math.asin %arg0 : f32
61+
return
62+
}
63+
64+
// -----
65+
66+
// CHECK-LABEL: func.func @acos_to_call_opaque(
67+
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
68+
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "acos"(%[[VAL_0]]) : (f32) -> f32
69+
// CHECK: return
70+
// CHECK: }
71+
func.func @acos_to_call_opaque(%arg0: f32) {
72+
%1 = math.acos %arg0 : f32
73+
return
74+
}
75+
76+
// -----
77+
78+
// CHECK-LABEL: func.func @atan2_to_call_opaque(
79+
// CHECK-SAME: %[[VAL_0:.*]]: f32,
80+
// CHECK-SAME: %[[VAL_1:.*]]: f32) {
81+
// CHECK: %[[VAL_2:.*]] = emitc.call_opaque "atan2"(%[[VAL_0]], %[[VAL_1]]) : (f32, f32) -> f32
82+
// CHECK: return
83+
// CHECK: }
84+
func.func @atan2_to_call_opaque(%arg0: f32, %arg1: f32) {
85+
%1 = math.atan2 %arg0, %arg1 : f32
86+
return
87+
}
88+
89+
// -----
90+
91+
// CHECK-LABEL: func.func @ceil_to_call_opaque(
92+
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
93+
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "ceil"(%[[VAL_0]]) : (f32) -> f32
94+
// CHECK: return
95+
// CHECK: }
96+
func.func @ceil_to_call_opaque(%arg0: f32) {
97+
%1 = math.ceil %arg0 : f32
98+
return
99+
}
100+
101+
// -----
102+
103+
// CHECK-LABEL: func.func @exp_to_call_opaque(
104+
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
105+
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "exp"(%[[VAL_0]]) : (f32) -> f32
106+
// CHECK: return
107+
// CHECK: }
108+
func.func @exp_to_call_opaque(%arg0: f32) {
109+
%1 = math.exp %arg0 : f32
110+
return
111+
}
112+
113+
114+
// -----
115+
116+
// CHECK-LABEL: func.func @fpowi_to_call_opaque(
117+
// CHECK-SAME: %[[VAL_0:.*]]: f32,
118+
// CHECK-SAME: %[[VAL_1:.*]]: i32) {
119+
// CHECK: %[[VAL_2:.*]] = emitc.call_opaque "powf"(%[[VAL_0]], %[[VAL_1]]) : (f32, i32) -> f32
120+
// CHECK: return
121+
// CHECK: }
122+
func.func @fpowi_to_call_opaque(%arg0: f32, %arg1: i32) {
123+
%1 = math.fpowi %arg0, %arg1 : f32, i32
124+
return
125+
}
126+
127+
// -----
128+
129+
// CHECK-LABEL: func.func @ipowi_to_call_opaque(
130+
// CHECK-SAME: %[[VAL_0:.*]]: i32,
131+
// CHECK-SAME: %[[VAL_1:.*]]: i32) {
132+
// CHECK: %[[VAL_2:.*]] = emitc.call_opaque "pow"(%[[VAL_0]], %[[VAL_1]]) : (i32, i32) -> i32
133+
// CHECK: return
134+
// CHECK: }
135+
func.func @ipowi_to_call_opaque(%arg0: i32, %arg1: i32) {
136+
%1 = math.ipowi %arg0, %arg1 : i32
137+
return
138+
}
139+
140+

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4201,6 +4201,7 @@ cc_library(
42014201
":IndexToLLVM",
42024202
":IndexToSPIRV",
42034203
":LinalgToStandard",
4204+
":MathToEmitC",
42044205
":MathToFuncs",
42054206
":MathToLLVM",
42064207
":MathToLibm",
@@ -8721,6 +8722,27 @@ cc_library(
87218722
],
87228723
)
87238724

8725+
cc_library(
8726+
name = "MathToEmitC",
8727+
srcs = glob([
8728+
"lib/Conversion/MathToEmitC/*.cpp",
8729+
]),
8730+
hdrs = glob([
8731+
"include/mlir/Conversion/MathToEmitC/*.h",
8732+
]),
8733+
includes = [
8734+
"include",
8735+
"lib/Conversion/MathToEmitC",
8736+
],
8737+
deps = [
8738+
":ConversionPassIncGen",
8739+
":EmitCDialect",
8740+
":MathDialect",
8741+
":Pass",
8742+
":TransformUtils",
8743+
],
8744+
)
8745+
87248746
cc_library(
87258747
name = "MathToFuncs",
87268748
srcs = glob(["lib/Conversion/MathToFuncs/*.cpp"]),

0 commit comments

Comments
 (0)