Skip to content

Commit f6c2406

Browse files
committed
[MLIR][MathToEmitC] Refactor code, add tests for unsupported types, and restrict to f32 only
Refactored code (added newlines and nits). Added tests to verify behavior with unsupported types. Now only f32 is supported. Deleted generation of emitc.include Changed the conversion to apply at the operation level instead of the module level.
1 parent ad9af42 commit f6c2406

File tree

8 files changed

+69
-93
lines changed

8 files changed

+69
-93
lines changed

mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- MathToEmitC.h - Math to EmitC Pass -----------*- C++ -*-===//
1+
//===- MathToEmitC.h - Math to EmitCPatterns -------------------*- C++ -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -13,7 +13,6 @@ namespace mlir {
1313
class RewritePatternSet;
1414

1515
void populateConvertMathToEmitCPatterns(RewritePatternSet &patterns);
16-
1716
} // namespace mlir
1817

1918
#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H

mlir/include/mlir/Conversion/MathToEmitC/MathToEmitCPass.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- MathToEmitCPass.h - Math to EmitC Pass -----------------*- C++ -*-===//
1+
//===- MathToEmitCPass.h - Math to EmitC Pass -------------------*- C++ -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -18,4 +18,4 @@ class Pass;
1818
#include "mlir/Conversion/Passes.h.inc"
1919
} // namespace mlir
2020

21-
#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H
21+
#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H

mlir/include/mlir/Conversion/Passes.td

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -784,17 +784,14 @@ def ConvertMathToSPIRV : Pass<"convert-math-to-spirv"> {
784784
// MathToEmitC
785785
//===----------------------------------------------------------------------===//
786786

787-
def ConvertMathToEmitC : Pass<"convert-math-to-emitc", "ModuleOp"> {
787+
def ConvertMathToEmitC : Pass<"convert-math-to-emitc"> {
788788
let summary = "Convert some Math operations to EmitC Call_opaque";
789789
let description = [{
790-
This pass converts supported Math ops to call_opaque calls to compiler generated
791-
functions implementing these operations in software.
792-
Unlike convert-math-to-funcs pass, this pass uses call_opaque,
793-
therefore enables us to overload the same funtion with different argument types
790+
This pass converts supported Math ops to `opaque_call` ops targeting libc/libm
791+
functions. Unlike convert-math-to-funcs pass, converting to `call_opaque` ops
792+
allows to overload the same function with different argument types.
794793
}];
795-
let dependentDialects = ["emitc::EmitCDialect",
796-
"math::MathDialect"
797-
];
794+
let dependentDialects = ["emitc::EmitCDialect"];
798795
}
799796

800797
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/MathToEmitC/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ add_mlir_conversion_library(MLIRMathToEmitC
1212
Core
1313

1414
LINK_LIBS PUBLIC
15-
MLIRLLVMCommonConversion
1615
MLIREmitCDialect
1716
MLIRMathDialect
1817
MLIRPass
Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- MathToEmitC.cpp - Math to EmitC Pass Implementation ----------===//
1+
//===- MathToEmitC.cpp - Math to EmitC Patterns ----------------*- C++ -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -16,7 +16,7 @@ using namespace mlir;
1616

1717
namespace {
1818
template <typename OpType>
19-
class LowerToEmitCCallOpaque : public mlir::OpRewritePattern<OpType> {
19+
class LowerToEmitCCallOpaque : public OpRewritePattern<OpType> {
2020
std::string calleeStr;
2121

2222
public:
@@ -30,19 +30,12 @@ class LowerToEmitCCallOpaque : public mlir::OpRewritePattern<OpType> {
3030
template <typename OpType>
3131
LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
3232
OpType op, PatternRewriter &rewriter) const {
33-
auto actualOp = mlir::cast<OpType>(op);
34-
if (!llvm::all_of(
35-
actualOp->getOperands(),
36-
[](Value operand) { return isa<FloatType>(operand.getType()); }) ||
37-
!llvm::all_of(actualOp->getResultTypes(),
38-
[](mlir::Type type) { return isa<FloatType>(type); })) {
39-
op.emitError("non-float types are not supported");
40-
return mlir::failure();
41-
}
42-
mlir::StringAttr callee = rewriter.getStringAttr(calleeStr);
43-
rewriter.replaceOpWithNewOp<mlir::emitc::CallOpaqueOp>(
44-
actualOp, actualOp.getType(), callee, actualOp->getOperands());
45-
return mlir::success();
33+
if (!llvm::all_of(op->getOperandTypes(), llvm::IsaPred<Float32Type, Float64Type>)||
34+
!llvm::all_of(op->getResultTypes(),llvm::IsaPred<Float32Type, Float64Type>))
35+
return rewriter.notifyMatchFailure(op.getLoc(), "expected all operands and results to be of type f32 or f64");
36+
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
37+
op, op.getType(), calleeStr, op->getOperands());
38+
return success();
4639
}
4740

4841
} // namespace
@@ -51,15 +44,15 @@ LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
5144
// using function names consistent with those in <math.h>.
5245
void mlir::populateConvertMathToEmitCPatterns(RewritePatternSet &patterns) {
5346
auto *context = patterns.getContext();
54-
patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor");
55-
patterns.insert<LowerToEmitCCallOpaque<math::RoundEvenOp>>(context, "rint");
56-
patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "exp");
57-
patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cos");
58-
patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sin");
59-
patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acos");
60-
patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asin");
61-
patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2");
62-
patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil");
63-
patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs");
64-
patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "pow");
47+
patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floorf");
48+
patterns.insert<LowerToEmitCCallOpaque<math::RoundOp>>(context, "roundf");
49+
patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "expf");
50+
patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cosf");
51+
patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sinf");
52+
patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acosf");
53+
patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asinf");
54+
patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2f");
55+
patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceilf");
56+
patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabsf");
57+
patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "powf");
6558
}

mlir/lib/Conversion/MathToEmitC/MathToEmitCPass.cpp

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,33 +26,25 @@ using namespace mlir;
2626
namespace {
2727

2828
// Replaces Math operations with `emitc.call_opaque` operations.
29-
struct ConvertMathToEmitCPass
30-
: public impl::ConvertMathToEmitCBase<ConvertMathToEmitCPass> {
29+
struct ConvertMathToEmitC
30+
: public impl::ConvertMathToEmitCBase<ConvertMathToEmitC> {
3131
public:
3232
void runOnOperation() final;
3333
};
3434

3535
} // end anonymous namespace
3636

37-
void ConvertMathToEmitCPass::runOnOperation() {
38-
auto moduleOp = getOperation();
39-
// Insert #include <math.h> at the beginning of the module
40-
OpBuilder builder(moduleOp.getBodyRegion());
41-
builder.setInsertionPointToStart(&moduleOp.getBodyRegion().front());
42-
builder.create<emitc::IncludeOp>(moduleOp.getLoc(),
43-
builder.getStringAttr("math.h"));
44-
37+
void ConvertMathToEmitC::runOnOperation() {
4538
ConversionTarget target(getContext());
4639
target.addLegalOp<emitc::CallOpaqueOp>();
4740

4841
target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundEvenOp,
4942
math::CosOp, math::SinOp, math::Atan2Op, math::CeilOp,
50-
math::AcosOp, math::AsinOp, math::AbsFOp, math::PowFOp,
51-
math::FPowIOp, math::IPowIOp>();
43+
math::AcosOp, math::AsinOp, math::AbsFOp, math::PowFOp>();
5244

5345
RewritePatternSet patterns(&getContext());
5446
populateConvertMathToEmitCPatterns(patterns);
5547

56-
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
48+
if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
5749
signalPassFailure();
58-
}
50+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt -split-input-file -convert-math-to-emitc -verify-diagnostics %s
2+
3+
func.func @unsupported_tensor_type(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
4+
// expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}}
5+
%0 = math.absf %arg0 : tensor<4xf32>
6+
return %0 : tensor<4xf32>
7+
}
8+
9+
// -----
10+
11+
func.func @unsupported_f16_type(%arg0 : f16) -> f16 {
12+
// expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}}
13+
%0 = math.absf %arg0 : f16
14+
return %0 : f16
15+
}
16+
17+
// -----
18+
19+
func.func @unsupported_f128_type(%arg0 : f128) -> f128 {
20+
// expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}}
21+
%0 = math.absf %arg0 : f128
22+
return %0 : f128
23+
}
Lines changed: 13 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,133 +1,106 @@
1-
// RUN: mlir-opt --split-input-file -convert-math-to-emitc -verify-diagnostics %s | FileCheck %s
1+
// RUN: mlir-opt -convert-math-to-emitc %s | FileCheck %s
22

3-
// CHECK-LABEL: emitc.include "math.h"
43

54
// CHECK-LABEL: func.func @absf_to_call_opaque(
65
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
7-
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "fabs"(%[[VAL_0]]) : (f32) -> f32
6+
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "fabsf"(%[[VAL_0]]) : (f32) -> f32
87
// CHECK: return
98
// CHECK: }
109
func.func @absf_to_call_opaque(%arg0: f32) {
1110
%1 = math.absf %arg0 : f32
1211
return
1312
}
14-
15-
// -----
16-
1713
// CHECK-LABEL: func.func @floor_to_call_opaque(
1814
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
19-
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "floor"(%[[VAL_0]]) : (f32) -> f32
15+
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "floorf"(%[[VAL_0]]) : (f32) -> f32
2016
// CHECK: return
2117
// CHECK: }
2218
func.func @floor_to_call_opaque(%arg0: f32) {
2319
%1 = math.floor %arg0 : f32
2420
return
2521
}
26-
27-
// -----
28-
2922
// CHECK-LABEL: func.func @sin_to_call_opaque(
3023
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
31-
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "sin"(%[[VAL_0]]) : (f32) -> f32
24+
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "sinf"(%[[VAL_0]]) : (f32) -> f32
3225
// CHECK: return
3326
// CHECK: }
3427
func.func @sin_to_call_opaque(%arg0: f32) {
3528
%1 = math.sin %arg0 : f32
3629
return
3730
}
3831

39-
// -----
40-
4132
// CHECK-LABEL: func.func @cos_to_call_opaque(
4233
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
43-
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "cos"(%[[VAL_0]]) : (f32) -> f32
34+
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "cosf"(%[[VAL_0]]) : (f32) -> f32
4435
// CHECK: return
4536
// CHECK: }
4637
func.func @cos_to_call_opaque(%arg0: f32) {
4738
%1 = math.cos %arg0 : f32
4839
return
4940
}
5041

51-
52-
// -----
53-
5442
// CHECK-LABEL: func.func @asin_to_call_opaque(
5543
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
56-
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "asin"(%[[VAL_0]]) : (f32) -> f32
44+
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "asinf"(%[[VAL_0]]) : (f32) -> f32
5745
// CHECK: return
5846
// CHECK: }
5947
func.func @asin_to_call_opaque(%arg0: f32) {
6048
%1 = math.asin %arg0 : f32
6149
return
6250
}
6351

64-
// -----
65-
6652
// CHECK-LABEL: func.func @acos_to_call_opaque(
6753
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
68-
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "acos"(%[[VAL_0]]) : (f32) -> f32
54+
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "acosf"(%[[VAL_0]]) : (f32) -> f32
6955
// CHECK: return
7056
// CHECK: }
7157
func.func @acos_to_call_opaque(%arg0: f32) {
7258
%1 = math.acos %arg0 : f32
7359
return
7460
}
7561

76-
// -----
77-
7862
// CHECK-LABEL: func.func @atan2_to_call_opaque(
7963
// CHECK-SAME: %[[VAL_0:.*]]: f32,
8064
// CHECK-SAME: %[[VAL_1:.*]]: f32) {
81-
// CHECK: %[[VAL_2:.*]] = emitc.call_opaque "atan2"(%[[VAL_0]], %[[VAL_1]]) : (f32, f32) -> f32
65+
// CHECK: %[[VAL_2:.*]] = emitc.call_opaque "atan2f"(%[[VAL_0]], %[[VAL_1]]) : (f32, f32) -> f32
8266
// CHECK: return
8367
// CHECK: }
8468
func.func @atan2_to_call_opaque(%arg0: f32, %arg1: f32) {
8569
%1 = math.atan2 %arg0, %arg1 : f32
8670
return
8771
}
8872

89-
// -----
9073

9174
// CHECK-LABEL: func.func @ceil_to_call_opaque(
9275
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
93-
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "ceil"(%[[VAL_0]]) : (f32) -> f32
76+
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "ceilf"(%[[VAL_0]]) : (f32) -> f32
9477
// CHECK: return
9578
// CHECK: }
9679
func.func @ceil_to_call_opaque(%arg0: f32) {
9780
%1 = math.ceil %arg0 : f32
9881
return
9982
}
10083

101-
// -----
102-
10384
// CHECK-LABEL: func.func @exp_to_call_opaque(
10485
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
105-
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "exp"(%[[VAL_0]]) : (f32) -> f32
86+
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "expf"(%[[VAL_0]]) : (f32) -> f32
10687
// CHECK: return
10788
// CHECK: }
10889
func.func @exp_to_call_opaque(%arg0: f32) {
10990
%1 = math.exp %arg0 : f32
11091
return
11192
}
11293

113-
// -----
11494

11595
// CHECK-LABEL: func.func @powf_to_call_opaque(
116-
// CHECK-SAME: %[[VAL_0:.*]]: f32,
117-
// CHECK-SAME: %[[VAL_1:.*]]: f32) {
118-
// CHECK: %[[VAL_2:.*]] = emitc.call_opaque "pow"(%[[VAL_0]], %[[VAL_1]]) : (f32, f32) -> f32
96+
// CHECK-SAME: %[[VAL_0:.*]]: f32,
97+
// CHECK-SAME: %[[VAL_1:.*]]: f32) {
98+
// CHECK: %[[VAL_2:.*]] = emitc.call_opaque "powf"(%[[VAL_0]], %[[VAL_1]]) : (f32, f32) -> f32
11999
// CHECK: return
120100
// CHECK: }
121101
func.func @powf_to_call_opaque(%arg0: f32, %arg1: f32) {
122102
%1 = math.powf %arg0, %arg1 : f32
123103
return
124104
}
125105

126-
// -----
127106

128-
func.func @test(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
129-
// expected-error @+2 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}}
130-
// expected-error @+1 {{non-float types are not supported}}
131-
%0 = math.absf %arg0 : tensor<4xf32>
132-
return %0 : tensor<4xf32>
133-
}

0 commit comments

Comments
 (0)