Skip to content

Commit ad9af42

Browse files
committed
[MLIR][MathToEmitC] Ensure scalar type handling and refactor
This patch ensures that the MathToEmitC pass only converts scalar `FloatType`s, avoiding invalid conversions of non-scalar types like tensors. - **Validation:** Added checks to convert only scalar types. - **Refactoring:** Moved implementation to `MathToEmitCPass.cpp` and split headers. - **Testing:** Added test cases to ensure proper error handling for non-scalar types.
1 parent d5bd00c commit ad9af42

File tree

8 files changed

+120
-89
lines changed

8 files changed

+120
-89
lines changed

mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,10 @@
99
#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
1010
#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
1111

12-
#include "mlir/IR/BuiltinOps.h"
13-
#include "mlir/Pass/Pass.h"
14-
#include <memory>
15-
1612
namespace mlir {
13+
class RewritePatternSet;
1714

18-
#define GEN_PASS_DECL_CONVERTMATHTOEMITC
19-
#include "mlir/Conversion/Passes.h.inc"
20-
21-
std::unique_ptr<OperationPass<mlir::ModuleOp>> createConvertMathToEmitCPass();
15+
void populateConvertMathToEmitCPatterns(RewritePatternSet &patterns);
2216

2317
} // namespace mlir
2418

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- MathToEmitCPass.h - Math to EmitC Pass -----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H
10+
#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class Pass;
16+
17+
#define GEN_PASS_DECL_CONVERTMATHTOEMITC
18+
#include "mlir/Conversion/Passes.h.inc"
19+
} // namespace mlir
20+
21+
#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
4444
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
4545
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
46-
#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
46+
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
4747
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
4848
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
4949
#include "mlir/Conversion/MathToLibm/MathToLibm.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -792,8 +792,6 @@ def ConvertMathToEmitC : Pass<"convert-math-to-emitc", "ModuleOp"> {
792792
Unlike convert-math-to-funcs pass, this pass uses call_opaque,
793793
therefore enables us to overload the same funtion with different argument types
794794
}];
795-
796-
let constructor = "mlir::createConvertMathToEmitCPass()";
797795
let dependentDialects = ["emitc::EmitCDialect",
798796
"math::MathDialect"
799797
];

mlir/lib/Conversion/MathToEmitC/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_conversion_library(MLIRMathToEmitC
22
MathToEmitC.cpp
3+
MathToEmitCPass.cpp
34

45
ADDITIONAL_HEADER_DIRS
56
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToEmitC
Lines changed: 25 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
//===- MathToEmitC.cpp - Math to EmitC Pass Implementation ----------===//
32
//
43
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -8,43 +7,49 @@
87
//===----------------------------------------------------------------------===//
98

109
#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
10+
1111
#include "mlir/Dialect/EmitC/IR/EmitC.h"
1212
#include "mlir/Dialect/Math/IR/Math.h"
13-
#include "mlir/Pass/Pass.h"
1413
#include "mlir/Transforms/DialectConversion.h"
1514

16-
namespace mlir {
17-
#define GEN_PASS_DEF_CONVERTMATHTOEMITC
18-
#include "mlir/Conversion/Passes.h.inc"
19-
} // namespace mlir
20-
2115
using namespace mlir;
22-
namespace {
23-
24-
// Replaces Math operations with `emitc.call_opaque` operations.
25-
struct ConvertMathToEmitCPass
26-
: public impl::ConvertMathToEmitCBase<ConvertMathToEmitCPass> {
27-
public:
28-
void runOnOperation() final;
29-
};
30-
31-
} // end anonymous namespace
3216

17+
namespace {
3318
template <typename OpType>
3419
class LowerToEmitCCallOpaque : public mlir::OpRewritePattern<OpType> {
3520
std::string calleeStr;
3621

3722
public:
3823
LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr)
39-
: OpRewritePattern<OpType>(context), calleeStr(calleeStr) {}
24+
: OpRewritePattern<OpType>(context), calleeStr(std::move(calleeStr)) {}
4025

4126
LogicalResult matchAndRewrite(OpType op,
4227
PatternRewriter &rewriter) const override;
4328
};
4429

30+
template <typename OpType>
31+
LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
32+
OpType op, PatternRewriter &rewriter) const {
33+
auto actualOp = mlir::cast<OpType>(op);
34+
if (!llvm::all_of(
35+
actualOp->getOperands(),
36+
[](Value operand) { return isa<FloatType>(operand.getType()); }) ||
37+
!llvm::all_of(actualOp->getResultTypes(),
38+
[](mlir::Type type) { return isa<FloatType>(type); })) {
39+
op.emitError("non-float types are not supported");
40+
return mlir::failure();
41+
}
42+
mlir::StringAttr callee = rewriter.getStringAttr(calleeStr);
43+
rewriter.replaceOpWithNewOp<mlir::emitc::CallOpaqueOp>(
44+
actualOp, actualOp.getType(), callee, actualOp->getOperands());
45+
return mlir::success();
46+
}
47+
48+
} // namespace
49+
4550
// Populates patterns to replace `math` operations with `emitc.call_opaque`,
4651
// using function names consistent with those in <math.h>.
47-
static void populateConvertMathToEmitCPatterns(RewritePatternSet &patterns) {
52+
void mlir::populateConvertMathToEmitCPatterns(RewritePatternSet &patterns) {
4853
auto *context = patterns.getContext();
4954
patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor");
5055
patterns.insert<LowerToEmitCCallOpaque<math::RoundEvenOp>>(context, "rint");
@@ -56,44 +61,5 @@ static void populateConvertMathToEmitCPatterns(RewritePatternSet &patterns) {
5661
patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2");
5762
patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil");
5863
patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs");
59-
patterns.insert<LowerToEmitCCallOpaque<math::FPowIOp>>(context, "powf");
60-
patterns.insert<LowerToEmitCCallOpaque<math::IPowIOp>>(context, "pow");
61-
}
62-
63-
template <typename OpType>
64-
LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
65-
OpType op, PatternRewriter &rewriter) const {
66-
mlir::StringAttr callee = rewriter.getStringAttr(calleeStr);
67-
auto actualOp = mlir::cast<OpType>(op);
68-
rewriter.replaceOpWithNewOp<mlir::emitc::CallOpaqueOp>(
69-
actualOp, actualOp.getType(), callee, actualOp->getOperands());
70-
return mlir::success();
71-
}
72-
73-
void ConvertMathToEmitCPass::runOnOperation() {
74-
auto moduleOp = getOperation();
75-
// Insert #include <math.h> at the beginning of the module
76-
OpBuilder builder(moduleOp.getBodyRegion());
77-
builder.setInsertionPointToStart(&moduleOp.getBodyRegion().front());
78-
builder.create<emitc::IncludeOp>(moduleOp.getLoc(),
79-
builder.getStringAttr("math.h"));
80-
81-
ConversionTarget target(getContext());
82-
target.addLegalOp<emitc::CallOpaqueOp>();
83-
84-
target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundEvenOp,
85-
math::CosOp, math::SinOp, math::Atan2Op, math::CeilOp,
86-
math::AcosOp, math::AsinOp, math::AbsFOp, math::PowFOp,
87-
math::FPowIOp, math::IPowIOp>();
88-
89-
RewritePatternSet patterns(&getContext());
90-
populateConvertMathToEmitCPatterns(patterns);
91-
92-
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
93-
signalPassFailure();
94-
}
95-
96-
std::unique_ptr<OperationPass<mlir::ModuleOp>>
97-
mlir::createConvertMathToEmitCPass() {
98-
return std::make_unique<ConvertMathToEmitCPass>();
64+
patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "pow");
9965
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
//===- MathToEmitCPass.cpp - Math to EmitC Pass -----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to convert the Math dialect to the EmitC dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
14+
#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
15+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
16+
#include "mlir/Dialect/Math/IR/Math.h"
17+
#include "mlir/Pass/Pass.h"
18+
#include "mlir/Transforms/DialectConversion.h"
19+
20+
namespace mlir {
21+
#define GEN_PASS_DEF_CONVERTMATHTOEMITC
22+
#include "mlir/Conversion/Passes.h.inc"
23+
} // namespace mlir
24+
25+
using namespace mlir;
26+
namespace {
27+
28+
// Replaces Math operations with `emitc.call_opaque` operations.
29+
struct ConvertMathToEmitCPass
30+
: public impl::ConvertMathToEmitCBase<ConvertMathToEmitCPass> {
31+
public:
32+
void runOnOperation() final;
33+
};
34+
35+
} // end anonymous namespace
36+
37+
void ConvertMathToEmitCPass::runOnOperation() {
38+
auto moduleOp = getOperation();
39+
// Insert #include <math.h> at the beginning of the module
40+
OpBuilder builder(moduleOp.getBodyRegion());
41+
builder.setInsertionPointToStart(&moduleOp.getBodyRegion().front());
42+
builder.create<emitc::IncludeOp>(moduleOp.getLoc(),
43+
builder.getStringAttr("math.h"));
44+
45+
ConversionTarget target(getContext());
46+
target.addLegalOp<emitc::CallOpaqueOp>();
47+
48+
target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundEvenOp,
49+
math::CosOp, math::SinOp, math::Atan2Op, math::CeilOp,
50+
math::AcosOp, math::AsinOp, math::AbsFOp, math::PowFOp,
51+
math::FPowIOp, math::IPowIOp>();
52+
53+
RewritePatternSet patterns(&getContext());
54+
populateConvertMathToEmitCPatterns(patterns);
55+
56+
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
57+
signalPassFailure();
58+
}

mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt --split-input-file -convert-math-to-emitc %s | FileCheck %s
1+
// RUN: mlir-opt --split-input-file -convert-math-to-emitc -verify-diagnostics %s | FileCheck %s
22

33
// CHECK-LABEL: emitc.include "math.h"
44

@@ -110,31 +110,24 @@ func.func @exp_to_call_opaque(%arg0: f32) {
110110
return
111111
}
112112

113-
114113
// -----
115114

116-
// CHECK-LABEL: func.func @fpowi_to_call_opaque(
115+
// CHECK-LABEL: func.func @powf_to_call_opaque(
117116
// CHECK-SAME: %[[VAL_0:.*]]: f32,
118-
// CHECK-SAME: %[[VAL_1:.*]]: i32) {
119-
// CHECK: %[[VAL_2:.*]] = emitc.call_opaque "powf"(%[[VAL_0]], %[[VAL_1]]) : (f32, i32) -> f32
117+
// CHECK-SAME: %[[VAL_1:.*]]: f32) {
118+
// CHECK: %[[VAL_2:.*]] = emitc.call_opaque "pow"(%[[VAL_0]], %[[VAL_1]]) : (f32, f32) -> f32
120119
// CHECK: return
121120
// CHECK: }
122-
func.func @fpowi_to_call_opaque(%arg0: f32, %arg1: i32) {
123-
%1 = math.fpowi %arg0, %arg1 : f32, i32
121+
func.func @powf_to_call_opaque(%arg0: f32, %arg1: f32) {
122+
%1 = math.powf %arg0, %arg1 : f32
124123
return
125124
}
126125

127126
// -----
128127

129-
// CHECK-LABEL: func.func @ipowi_to_call_opaque(
130-
// CHECK-SAME: %[[VAL_0:.*]]: i32,
131-
// CHECK-SAME: %[[VAL_1:.*]]: i32) {
132-
// CHECK: %[[VAL_2:.*]] = emitc.call_opaque "pow"(%[[VAL_0]], %[[VAL_1]]) : (i32, i32) -> i32
133-
// CHECK: return
134-
// CHECK: }
135-
func.func @ipowi_to_call_opaque(%arg0: i32, %arg1: i32) {
136-
%1 = math.ipowi %arg0, %arg1 : i32
137-
return
138-
}
139-
140-
128+
func.func @test(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
129+
// expected-error @+2 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}}
130+
// expected-error @+1 {{non-float types are not supported}}
131+
%0 = math.absf %arg0 : tensor<4xf32>
132+
return %0 : tensor<4xf32>
133+
}

0 commit comments

Comments
 (0)