Skip to content

Commit 23ada46

Browse files
committed
[MLIR][MathToEmitC] Add support for C and C++ targets with Lit tests
C target: Generates specific callee names based on operand types by appending the appropriate suffix C++ target: Uses standard library functions with the std:: namespace Updated LIT tests
1 parent f6c2406 commit 23ada46

File tree

7 files changed

+99
-93
lines changed

7 files changed

+99
-93
lines changed

mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@
88

99
#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
1010
#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
11+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
1112

1213
namespace mlir {
1314
class RewritePatternSet;
1415

15-
void populateConvertMathToEmitCPatterns(RewritePatternSet &patterns);
16+
void populateConvertMathToEmitCPatterns(
17+
RewritePatternSet &patterns,
18+
emitc::MathToEmitCLanguageTarget languageTarget);
1619
} // namespace mlir
1720

1821
#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H

mlir/include/mlir/Conversion/MathToEmitC/MathToEmitCPass.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H
1010
#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H
1111

12+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
1213
#include <memory>
13-
1414
namespace mlir {
1515
class Pass;
1616

mlir/include/mlir/Conversion/Passes.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,14 @@ def ConvertMathToEmitC : Pass<"convert-math-to-emitc"> {
792792
allows to overload the same function with different argument types.
793793
}];
794794
let dependentDialects = ["emitc::EmitCDialect"];
795+
let options = [
796+
Option<"languageTarget", "language-target", "::mlir::emitc::MathToEmitCLanguageTarget",
797+
/*default=*/"::mlir::emitc::MathToEmitCLanguageTarget::CPP", "Select the language target for callees (C or CPP).",
798+
[{::llvm::cl::values(
799+
clEnumValN(::mlir::emitc::MathToEmitCLanguageTarget::C, "C", "C"),
800+
clEnumValN(::mlir::emitc::MathToEmitCLanguageTarget::CPP, "CPP", "CPP")
801+
)}]>
802+
];
795803
}
796804

797805
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,13 @@ def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque"> {
6262

6363
def EmitC_OpaqueOrTypedAttr : AnyAttrOf<[EmitC_OpaqueAttr, TypedAttrInterface]>;
6464

65+
def MathToEmitCLanguageTarget : I32EnumAttr<"MathToEmitCLanguageTarget",
66+
"Specifies the language target for generating callees.", [
67+
I32EnumAttrCase<"C", 0, "Use C-style function names">,
68+
I32EnumAttrCase<"CPP", 1, "Use C++-style function names">
69+
]> {
70+
let cppNamespace = "::mlir::emitc";
71+
}
72+
73+
6574
#endif // MLIR_DIALECT_EMITC_IR_EMITCATTRIBUTES

mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,13 @@ namespace {
1818
template <typename OpType>
1919
class LowerToEmitCCallOpaque : public OpRewritePattern<OpType> {
2020
std::string calleeStr;
21+
emitc::MathToEmitCLanguageTarget languageTarget;
2122

2223
public:
23-
LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr)
24-
: OpRewritePattern<OpType>(context), calleeStr(std::move(calleeStr)) {}
24+
LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr,
25+
emitc::MathToEmitCLanguageTarget languageTarget)
26+
: OpRewritePattern<OpType>(context), calleeStr(std::move(calleeStr)),
27+
languageTarget(languageTarget) {}
2528

2629
LogicalResult matchAndRewrite(OpType op,
2730
PatternRewriter &rewriter) const override;
@@ -32,27 +35,49 @@ LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
3235
OpType op, PatternRewriter &rewriter) const {
3336
if (!llvm::all_of(op->getOperandTypes(), llvm::IsaPred<Float32Type, Float64Type>)||
3437
!llvm::all_of(op->getResultTypes(),llvm::IsaPred<Float32Type, Float64Type>))
35-
return rewriter.notifyMatchFailure(op.getLoc(), "expected all operands and results to be of type f32 or f64");
38+
return rewriter.notifyMatchFailure(
39+
op.getLoc(), "expected all operands and results to be of type f32");
40+
std::string modifiedCalleeStr = calleeStr;
41+
if (languageTarget == emitc::MathToEmitCLanguageTarget::CPP) {
42+
modifiedCalleeStr = "std::" + calleeStr;
43+
} else if (languageTarget == emitc::MathToEmitCLanguageTarget::C) {
44+
auto operandType = op->getOperandTypes()[0];
45+
if (operandType.isF32())
46+
modifiedCalleeStr = calleeStr + "f";
47+
}
3648
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
37-
op, op.getType(), calleeStr, op->getOperands());
49+
op, op.getType(), modifiedCalleeStr, op->getOperands());
3850
return success();
3951
}
4052

4153
} // namespace
4254

4355
// Populates patterns to replace `math` operations with `emitc.call_opaque`,
4456
// using function names consistent with those in <math.h>.
45-
void mlir::populateConvertMathToEmitCPatterns(RewritePatternSet &patterns) {
57+
void mlir::populateConvertMathToEmitCPatterns(
58+
RewritePatternSet &patterns,
59+
emitc::MathToEmitCLanguageTarget languageTarget) {
4660
auto *context = patterns.getContext();
47-
patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floorf");
48-
patterns.insert<LowerToEmitCCallOpaque<math::RoundOp>>(context, "roundf");
49-
patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "expf");
50-
patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cosf");
51-
patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sinf");
52-
patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acosf");
53-
patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asinf");
54-
patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2f");
55-
patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceilf");
56-
patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabsf");
57-
patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "powf");
61+
patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor",
62+
languageTarget);
63+
patterns.insert<LowerToEmitCCallOpaque<math::RoundOp>>(context, "round",
64+
languageTarget);
65+
patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "exp",
66+
languageTarget);
67+
patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cos",
68+
languageTarget);
69+
patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sin",
70+
languageTarget);
71+
patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acos",
72+
languageTarget);
73+
patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asin",
74+
languageTarget);
75+
patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2",
76+
languageTarget);
77+
patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil",
78+
languageTarget);
79+
patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs",
80+
languageTarget);
81+
patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "pow",
82+
languageTarget);
5883
}

mlir/lib/Conversion/MathToEmitC/MathToEmitCPass.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,13 @@ namespace {
2828
// Replaces Math operations with `emitc.call_opaque` operations.
2929
struct ConvertMathToEmitC
3030
: public impl::ConvertMathToEmitCBase<ConvertMathToEmitC> {
31+
using ConvertMathToEmitCBase::ConvertMathToEmitCBase;
32+
3133
public:
3234
void runOnOperation() final;
3335
};
3436

35-
} // end anonymous namespace
37+
} // namespace
3638

3739
void ConvertMathToEmitC::runOnOperation() {
3840
ConversionTarget target(getContext());
@@ -43,7 +45,7 @@ void ConvertMathToEmitC::runOnOperation() {
4345
math::AcosOp, math::AsinOp, math::AbsFOp, math::PowFOp>();
4446

4547
RewritePatternSet patterns(&getContext());
46-
populateConvertMathToEmitCPatterns(patterns);
48+
populateConvertMathToEmitCPatterns(patterns, languageTarget);
4749

4850
if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
4951
signalPassFailure();

mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir

Lines changed: 32 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,105 +1,64 @@
1-
// RUN: mlir-opt -convert-math-to-emitc %s | FileCheck %s
1+
// RUN: mlir-opt -convert-math-to-emitc=language-target=C %s | FileCheck %s --check-prefix=C
2+
// RUN: mlir-opt -convert-math-to-emitc=language-target=CPP %s | FileCheck %s --check-prefix=CPP
23

3-
4-
// CHECK-LABEL: func.func @absf_to_call_opaque(
5-
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
6-
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "fabsf"(%[[VAL_0]]) : (f32) -> f32
7-
// CHECK: return
8-
// CHECK: }
94
func.func @absf_to_call_opaque(%arg0: f32) {
5+
// C: emitc.call_opaque "fabsf"
6+
// CPP: emitc.call_opaque "std::fabs"
107
%1 = math.absf %arg0 : f32
118
return
129
}
13-
// CHECK-LABEL: func.func @floor_to_call_opaque(
14-
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
15-
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "floorf"(%[[VAL_0]]) : (f32) -> f32
16-
// CHECK: return
17-
// CHECK: }
1810
func.func @floor_to_call_opaque(%arg0: f32) {
11+
// C: emitc.call_opaque "floorf"
12+
// CPP: emitc.call_opaque "std::floor"
1913
%1 = math.floor %arg0 : f32
2014
return
2115
}
22-
// CHECK-LABEL: func.func @sin_to_call_opaque(
23-
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
24-
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "sinf"(%[[VAL_0]]) : (f32) -> f32
25-
// CHECK: return
26-
// CHECK: }
2716
func.func @sin_to_call_opaque(%arg0: f32) {
17+
// C: emitc.call_opaque "sinf"
18+
// CPP: emitc.call_opaque "std::sin"
2819
%1 = math.sin %arg0 : f32
2920
return
3021
}
31-
32-
// CHECK-LABEL: func.func @cos_to_call_opaque(
33-
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
34-
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "cosf"(%[[VAL_0]]) : (f32) -> f32
35-
// CHECK: return
36-
// CHECK: }
3722
func.func @cos_to_call_opaque(%arg0: f32) {
23+
// C: emitc.call_opaque "cosf"
24+
// CPP: emitc.call_opaque "std::cos"
3825
%1 = math.cos %arg0 : f32
3926
return
4027
}
41-
42-
// CHECK-LABEL: func.func @asin_to_call_opaque(
43-
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
44-
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "asinf"(%[[VAL_0]]) : (f32) -> f32
45-
// CHECK: return
46-
// CHECK: }
4728
func.func @asin_to_call_opaque(%arg0: f32) {
29+
// C: emitc.call_opaque "asinf"
30+
// CPP: emitc.call_opaque "std::asin"
4831
%1 = math.asin %arg0 : f32
4932
return
5033
}
51-
52-
// CHECK-LABEL: func.func @acos_to_call_opaque(
53-
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
54-
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "acosf"(%[[VAL_0]]) : (f32) -> f32
55-
// CHECK: return
56-
// CHECK: }
57-
func.func @acos_to_call_opaque(%arg0: f32) {
58-
%1 = math.acos %arg0 : f32
34+
func.func @acos_to_call_opaque(%arg0: f64) {
35+
// C: emitc.call_opaque "acos"
36+
// CPP: emitc.call_opaque "std::acos"
37+
%1 = math.acos %arg0 : f64
5938
return
6039
}
61-
62-
// CHECK-LABEL: func.func @atan2_to_call_opaque(
63-
// CHECK-SAME: %[[VAL_0:.*]]: f32,
64-
// CHECK-SAME: %[[VAL_1:.*]]: f32) {
65-
// CHECK: %[[VAL_2:.*]] = emitc.call_opaque "atan2f"(%[[VAL_0]], %[[VAL_1]]) : (f32, f32) -> f32
66-
// CHECK: return
67-
// CHECK: }
68-
func.func @atan2_to_call_opaque(%arg0: f32, %arg1: f32) {
69-
%1 = math.atan2 %arg0, %arg1 : f32
40+
func.func @atan2_to_call_opaque(%arg0: f64, %arg1: f64) {
41+
// C: emitc.call_opaque "atan2"
42+
// CPP: emitc.call_opaque "std::atan2"
43+
%1 = math.atan2 %arg0, %arg1 : f64
7044
return
7145
}
72-
73-
74-
// CHECK-LABEL: func.func @ceil_to_call_opaque(
75-
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
76-
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "ceilf"(%[[VAL_0]]) : (f32) -> f32
77-
// CHECK: return
78-
// CHECK: }
79-
func.func @ceil_to_call_opaque(%arg0: f32) {
80-
%1 = math.ceil %arg0 : f32
46+
func.func @ceil_to_call_opaque(%arg0: f64) {
47+
// C: emitc.call_opaque "ceil"
48+
// CPP: emitc.call_opaque "std::ceil"
49+
%1 = math.ceil %arg0 : f64
8150
return
8251
}
83-
84-
// CHECK-LABEL: func.func @exp_to_call_opaque(
85-
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
86-
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "expf"(%[[VAL_0]]) : (f32) -> f32
87-
// CHECK: return
88-
// CHECK: }
89-
func.func @exp_to_call_opaque(%arg0: f32) {
90-
%1 = math.exp %arg0 : f32
52+
func.func @exp_to_call_opaque(%arg0: f64) {
53+
// C: emitc.call_opaque "exp"
54+
// CPP: emitc.call_opaque "std::exp"
55+
%1 = math.exp %arg0 : f64
9156
return
9257
}
93-
94-
95-
// CHECK-LABEL: func.func @powf_to_call_opaque(
96-
// CHECK-SAME: %[[VAL_0:.*]]: f32,
97-
// CHECK-SAME: %[[VAL_1:.*]]: f32) {
98-
// CHECK: %[[VAL_2:.*]] = emitc.call_opaque "powf"(%[[VAL_0]], %[[VAL_1]]) : (f32, f32) -> f32
99-
// CHECK: return
100-
// CHECK: }
101-
func.func @powf_to_call_opaque(%arg0: f32, %arg1: f32) {
102-
%1 = math.powf %arg0, %arg1 : f32
58+
func.func @powf_to_call_opaque(%arg0: f64, %arg1: f64) {
59+
// C: emitc.call_opaque "pow"
60+
// CPP: emitc.call_opaque "std::pow"
61+
%1 = math.powf %arg0, %arg1 : f64
10362
return
10463
}
10564

0 commit comments

Comments
 (0)