1-
21// ===- MathToEmitC.cpp - Math to EmitC Pass Implementation ----------===//
32//
43// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
87// ===----------------------------------------------------------------------===//
98
109#include " mlir/Conversion/MathToEmitC/MathToEmitC.h"
10+
1111#include " mlir/Dialect/EmitC/IR/EmitC.h"
1212#include " mlir/Dialect/Math/IR/Math.h"
13- #include " mlir/Pass/Pass.h"
1413#include " mlir/Transforms/DialectConversion.h"
1514
16- namespace mlir {
17- #define GEN_PASS_DEF_CONVERTMATHTOEMITC
18- #include " mlir/Conversion/Passes.h.inc"
19- } // namespace mlir
20-
2115using namespace mlir ;
22- namespace {
23-
24- // Replaces Math operations with `emitc.call_opaque` operations.
25- struct ConvertMathToEmitCPass
26- : public impl::ConvertMathToEmitCBase<ConvertMathToEmitCPass> {
27- public:
28- void runOnOperation () final ;
29- };
30-
31- } // end anonymous namespace
3216
17+ namespace {
3318template <typename OpType>
3419class LowerToEmitCCallOpaque : public mlir ::OpRewritePattern<OpType> {
3520 std::string calleeStr;
3621
3722public:
3823 LowerToEmitCCallOpaque (MLIRContext *context, std::string calleeStr)
39- : OpRewritePattern<OpType>(context), calleeStr(calleeStr) {}
24+ : OpRewritePattern<OpType>(context), calleeStr(std::move( calleeStr) ) {}
4025
4126 LogicalResult matchAndRewrite (OpType op,
4227 PatternRewriter &rewriter) const override ;
4328};
4429
30+ template <typename OpType>
31+ LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
32+ OpType op, PatternRewriter &rewriter) const {
33+ auto actualOp = mlir::cast<OpType>(op);
34+ if (!llvm::all_of (
35+ actualOp->getOperands (),
36+ [](Value operand) { return isa<FloatType>(operand.getType ()); }) ||
37+ !llvm::all_of (actualOp->getResultTypes (),
38+ [](mlir::Type type) { return isa<FloatType>(type); })) {
39+ op.emitError (" non-float types are not supported" );
40+ return mlir::failure ();
41+ }
42+ mlir::StringAttr callee = rewriter.getStringAttr (calleeStr);
43+ rewriter.replaceOpWithNewOp <mlir::emitc::CallOpaqueOp>(
44+ actualOp, actualOp.getType (), callee, actualOp->getOperands ());
45+ return mlir::success ();
46+ }
47+
48+ } // namespace
49+
4550// Populates patterns to replace `math` operations with `emitc.call_opaque`,
4651// using function names consistent with those in <math.h>.
47- static void populateConvertMathToEmitCPatterns (RewritePatternSet &patterns) {
52+ void mlir:: populateConvertMathToEmitCPatterns (RewritePatternSet &patterns) {
4853 auto *context = patterns.getContext ();
4954 patterns.insert <LowerToEmitCCallOpaque<math::FloorOp>>(context, " floor" );
5055 patterns.insert <LowerToEmitCCallOpaque<math::RoundEvenOp>>(context, " rint" );
@@ -56,44 +61,5 @@ static void populateConvertMathToEmitCPatterns(RewritePatternSet &patterns) {
5661 patterns.insert <LowerToEmitCCallOpaque<math::Atan2Op>>(context, " atan2" );
5762 patterns.insert <LowerToEmitCCallOpaque<math::CeilOp>>(context, " ceil" );
5863 patterns.insert <LowerToEmitCCallOpaque<math::AbsFOp>>(context, " fabs" );
59- patterns.insert <LowerToEmitCCallOpaque<math::FPowIOp>>(context, " powf" );
60- patterns.insert <LowerToEmitCCallOpaque<math::IPowIOp>>(context, " pow" );
61- }
62-
63- template <typename OpType>
64- LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
65- OpType op, PatternRewriter &rewriter) const {
66- mlir::StringAttr callee = rewriter.getStringAttr (calleeStr);
67- auto actualOp = mlir::cast<OpType>(op);
68- rewriter.replaceOpWithNewOp <mlir::emitc::CallOpaqueOp>(
69- actualOp, actualOp.getType (), callee, actualOp->getOperands ());
70- return mlir::success ();
71- }
72-
73- void ConvertMathToEmitCPass::runOnOperation () {
74- auto moduleOp = getOperation ();
75- // Insert #include <math.h> at the beginning of the module
76- OpBuilder builder (moduleOp.getBodyRegion ());
77- builder.setInsertionPointToStart (&moduleOp.getBodyRegion ().front ());
78- builder.create <emitc::IncludeOp>(moduleOp.getLoc (),
79- builder.getStringAttr (" math.h" ));
80-
81- ConversionTarget target (getContext ());
82- target.addLegalOp <emitc::CallOpaqueOp>();
83-
84- target.addIllegalOp <math::FloorOp, math::ExpOp, math::RoundEvenOp,
85- math::CosOp, math::SinOp, math::Atan2Op, math::CeilOp,
86- math::AcosOp, math::AsinOp, math::AbsFOp, math::PowFOp,
87- math::FPowIOp, math::IPowIOp>();
88-
89- RewritePatternSet patterns (&getContext ());
90- populateConvertMathToEmitCPatterns (patterns);
91-
92- if (failed (applyPartialConversion (moduleOp, target, std::move (patterns))))
93- signalPassFailure ();
94- }
95-
96- std::unique_ptr<OperationPass<mlir::ModuleOp>>
97- mlir::createConvertMathToEmitCPass () {
98- return std::make_unique<ConvertMathToEmitCPass>();
64+ patterns.insert <LowerToEmitCCallOpaque<math::PowFOp>>(context, " pow" );
9965}
0 commit comments