-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][EmitC] Add MathToEmitC pass for math function lowering to EmitC #113799
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
d5bd00c
ad9af42
f6c2406
23ada46
8992fd5
52c35a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| //===- MathToEmitC.h - Math to EmitC Pass -----------*- C++ -*-===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H | ||
| #define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H | ||
|
|
||
| namespace mlir { | ||
| class RewritePatternSet; | ||
|
|
||
| void populateConvertMathToEmitCPatterns(RewritePatternSet &patterns); | ||
|
|
||
marbre marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } // namespace mlir | ||
|
|
||
| #endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| //===- MathToEmitCPass.h - Math to EmitC Pass -----------------*- C++ -*-===// | ||
marbre marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H | ||
| #define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H | ||
|
|
||
| #include <memory> | ||
|
|
||
| namespace mlir { | ||
| class Pass; | ||
|
|
||
| #define GEN_PASS_DECL_CONVERTMATHTOEMITC | ||
| #include "mlir/Conversion/Passes.h.inc" | ||
| } // namespace mlir | ||
|
|
||
| #endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H | ||
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -780,6 +780,23 @@ def ConvertMathToSPIRV : Pass<"convert-math-to-spirv"> { | |
| let dependentDialects = ["spirv::SPIRVDialect"]; | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // MathToEmitC | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| def ConvertMathToEmitC : Pass<"convert-math-to-emitc", "ModuleOp"> { | ||
| let summary = "Convert some Math operations to EmitC Call_opaque"; | ||
| let description = [{ | ||
| This pass converts supported Math ops to call_opaque calls to compiler generated | ||
| functions implementing these operations in software. | ||
|
||
| Unlike convert-math-to-funcs pass, this pass uses call_opaque, | ||
| therefore enables us to overload the same funtion with different argument types | ||
| }]; | ||
| let dependentDialects = ["emitc::EmitCDialect", | ||
| "math::MathDialect" | ||
| ]; | ||
marbre marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // MathToFuncs | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| add_mlir_conversion_library(MLIRMathToEmitC | ||
| MathToEmitC.cpp | ||
| MathToEmitCPass.cpp | ||
|
|
||
| ADDITIONAL_HEADER_DIRS | ||
| ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToEmitC | ||
|
|
||
| DEPENDS | ||
| MLIRConversionPassIncGen | ||
|
|
||
| LINK_COMPONENTS | ||
| Core | ||
|
|
||
| LINK_LIBS PUBLIC | ||
| MLIRLLVMCommonConversion | ||
marbre marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| MLIREmitCDialect | ||
| MLIRMathDialect | ||
| MLIRPass | ||
| MLIRTransforms | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| //===- MathToEmitC.cpp - Math to EmitC Pass Implementation ----------===// | ||
marbre marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "mlir/Conversion/MathToEmitC/MathToEmitC.h" | ||
|
|
||
| #include "mlir/Dialect/EmitC/IR/EmitC.h" | ||
| #include "mlir/Dialect/Math/IR/Math.h" | ||
| #include "mlir/Transforms/DialectConversion.h" | ||
|
|
||
| using namespace mlir; | ||
|
|
||
| namespace { | ||
| template <typename OpType> | ||
| class LowerToEmitCCallOpaque : public mlir::OpRewritePattern<OpType> { | ||
|
||
| std::string calleeStr; | ||
|
|
||
| public: | ||
| LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr) | ||
| : OpRewritePattern<OpType>(context), calleeStr(std::move(calleeStr)) {} | ||
|
|
||
| LogicalResult matchAndRewrite(OpType op, | ||
| PatternRewriter &rewriter) const override; | ||
| }; | ||
|
|
||
| template <typename OpType> | ||
| LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite( | ||
| OpType op, PatternRewriter &rewriter) const { | ||
| auto actualOp = mlir::cast<OpType>(op); | ||
|
||
| if (!llvm::all_of( | ||
|
||
| actualOp->getOperands(), | ||
| [](Value operand) { return isa<FloatType>(operand.getType()); }) || | ||
| !llvm::all_of(actualOp->getResultTypes(), | ||
| [](mlir::Type type) { return isa<FloatType>(type); })) { | ||
| op.emitError("non-float types are not supported"); | ||
| return mlir::failure(); | ||
|
||
| } | ||
| mlir::StringAttr callee = rewriter.getStringAttr(calleeStr); | ||
| rewriter.replaceOpWithNewOp<mlir::emitc::CallOpaqueOp>( | ||
| actualOp, actualOp.getType(), callee, actualOp->getOperands()); | ||
|
||
| return mlir::success(); | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| // Populates patterns to replace `math` operations with `emitc.call_opaque`, | ||
| // using function names consistent with those in <math.h>. | ||
| void mlir::populateConvertMathToEmitCPatterns(RewritePatternSet &patterns) { | ||
| auto *context = patterns.getContext(); | ||
| patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor"); | ||
| patterns.insert<LowerToEmitCCallOpaque<math::RoundEvenOp>>(context, "rint"); | ||
|
||
| patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "exp"); | ||
| patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cos"); | ||
| patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sin"); | ||
| patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acos"); | ||
| patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asin"); | ||
| patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2"); | ||
| patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil"); | ||
| patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs"); | ||
| patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "pow"); | ||
| } | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,58 @@ | ||||||
| //===- MathToEmitCPass.cpp - Math to EmitC Pass -----------------*- C++ -*-===// | ||||||
| // | ||||||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||||||
| // See https://llvm.org/LICENSE.txt for license information. | ||||||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||||
| // | ||||||
| //===----------------------------------------------------------------------===// | ||||||
| // | ||||||
| // This file implements a pass to convert the Math dialect to the EmitC dialect. | ||||||
| // | ||||||
| //===----------------------------------------------------------------------===// | ||||||
|
|
||||||
| #include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h" | ||||||
| #include "mlir/Conversion/MathToEmitC/MathToEmitC.h" | ||||||
| #include "mlir/Dialect/EmitC/IR/EmitC.h" | ||||||
| #include "mlir/Dialect/Math/IR/Math.h" | ||||||
| #include "mlir/Pass/Pass.h" | ||||||
| #include "mlir/Transforms/DialectConversion.h" | ||||||
|
|
||||||
| namespace mlir { | ||||||
| #define GEN_PASS_DEF_CONVERTMATHTOEMITC | ||||||
| #include "mlir/Conversion/Passes.h.inc" | ||||||
| } // namespace mlir | ||||||
|
|
||||||
| using namespace mlir; | ||||||
| namespace { | ||||||
|
|
||||||
| // Replaces Math operations with `emitc.call_opaque` operations. | ||||||
| struct ConvertMathToEmitCPass | ||||||
| : public impl::ConvertMathToEmitCBase<ConvertMathToEmitCPass> { | ||||||
marbre marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
| public: | ||||||
| void runOnOperation() final; | ||||||
| }; | ||||||
|
|
||||||
| } // end anonymous namespace | ||||||
marbre marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
|
|
||||||
| void ConvertMathToEmitCPass::runOnOperation() { | ||||||
marbre marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
| auto moduleOp = getOperation(); | ||||||
| // Insert #include <math.h> at the beginning of the module | ||||||
marbre marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
| OpBuilder builder(moduleOp.getBodyRegion()); | ||||||
| builder.setInsertionPointToStart(&moduleOp.getBodyRegion().front()); | ||||||
| builder.create<emitc::IncludeOp>(moduleOp.getLoc(), | ||||||
| builder.getStringAttr("math.h")); | ||||||
|
|
||||||
| ConversionTarget target(getContext()); | ||||||
| target.addLegalOp<emitc::CallOpaqueOp>(); | ||||||
|
|
||||||
| target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundEvenOp, | ||||||
|
||||||
| target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundEvenOp, | |
| target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundOp, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch!
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove the integer variants here too, please.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
newline
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,133 @@ | ||
| // RUN: mlir-opt --split-input-file -convert-math-to-emitc -verify-diagnostics %s | FileCheck %s | ||
|
||
|
|
||
| // CHECK-LABEL: emitc.include "math.h" | ||
|
|
||
| // CHECK-LABEL: func.func @absf_to_call_opaque( | ||
| // CHECK-SAME: %[[VAL_0:.*]]: f32) { | ||
| // CHECK: %[[VAL_1:.*]] = emitc.call_opaque "fabs"(%[[VAL_0]]) : (f32) -> f32 | ||
| // CHECK: return | ||
| // CHECK: } | ||
| func.func @absf_to_call_opaque(%arg0: f32) { | ||
| %1 = math.absf %arg0 : f32 | ||
|
||
| return | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: func.func @floor_to_call_opaque( | ||
| // CHECK-SAME: %[[VAL_0:.*]]: f32) { | ||
| // CHECK: %[[VAL_1:.*]] = emitc.call_opaque "floor"(%[[VAL_0]]) : (f32) -> f32 | ||
| // CHECK: return | ||
| // CHECK: } | ||
| func.func @floor_to_call_opaque(%arg0: f32) { | ||
| %1 = math.floor %arg0 : f32 | ||
| return | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: func.func @sin_to_call_opaque( | ||
| // CHECK-SAME: %[[VAL_0:.*]]: f32) { | ||
| // CHECK: %[[VAL_1:.*]] = emitc.call_opaque "sin"(%[[VAL_0]]) : (f32) -> f32 | ||
| // CHECK: return | ||
| // CHECK: } | ||
| func.func @sin_to_call_opaque(%arg0: f32) { | ||
| %1 = math.sin %arg0 : f32 | ||
| return | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: func.func @cos_to_call_opaque( | ||
| // CHECK-SAME: %[[VAL_0:.*]]: f32) { | ||
| // CHECK: %[[VAL_1:.*]] = emitc.call_opaque "cos"(%[[VAL_0]]) : (f32) -> f32 | ||
| // CHECK: return | ||
| // CHECK: } | ||
| func.func @cos_to_call_opaque(%arg0: f32) { | ||
| %1 = math.cos %arg0 : f32 | ||
| return | ||
| } | ||
|
|
||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: func.func @asin_to_call_opaque( | ||
| // CHECK-SAME: %[[VAL_0:.*]]: f32) { | ||
| // CHECK: %[[VAL_1:.*]] = emitc.call_opaque "asin"(%[[VAL_0]]) : (f32) -> f32 | ||
| // CHECK: return | ||
| // CHECK: } | ||
| func.func @asin_to_call_opaque(%arg0: f32) { | ||
| %1 = math.asin %arg0 : f32 | ||
| return | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: func.func @acos_to_call_opaque( | ||
| // CHECK-SAME: %[[VAL_0:.*]]: f32) { | ||
| // CHECK: %[[VAL_1:.*]] = emitc.call_opaque "acos"(%[[VAL_0]]) : (f32) -> f32 | ||
| // CHECK: return | ||
| // CHECK: } | ||
| func.func @acos_to_call_opaque(%arg0: f32) { | ||
| %1 = math.acos %arg0 : f32 | ||
| return | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: func.func @atan2_to_call_opaque( | ||
| // CHECK-SAME: %[[VAL_0:.*]]: f32, | ||
| // CHECK-SAME: %[[VAL_1:.*]]: f32) { | ||
| // CHECK: %[[VAL_2:.*]] = emitc.call_opaque "atan2"(%[[VAL_0]], %[[VAL_1]]) : (f32, f32) -> f32 | ||
| // CHECK: return | ||
| // CHECK: } | ||
| func.func @atan2_to_call_opaque(%arg0: f32, %arg1: f32) { | ||
| %1 = math.atan2 %arg0, %arg1 : f32 | ||
| return | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: func.func @ceil_to_call_opaque( | ||
| // CHECK-SAME: %[[VAL_0:.*]]: f32) { | ||
| // CHECK: %[[VAL_1:.*]] = emitc.call_opaque "ceil"(%[[VAL_0]]) : (f32) -> f32 | ||
| // CHECK: return | ||
| // CHECK: } | ||
| func.func @ceil_to_call_opaque(%arg0: f32) { | ||
| %1 = math.ceil %arg0 : f32 | ||
| return | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: func.func @exp_to_call_opaque( | ||
| // CHECK-SAME: %[[VAL_0:.*]]: f32) { | ||
| // CHECK: %[[VAL_1:.*]] = emitc.call_opaque "exp"(%[[VAL_0]]) : (f32) -> f32 | ||
| // CHECK: return | ||
| // CHECK: } | ||
| func.func @exp_to_call_opaque(%arg0: f32) { | ||
| %1 = math.exp %arg0 : f32 | ||
| return | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: func.func @powf_to_call_opaque( | ||
| // CHECK-SAME: %[[VAL_0:.*]]: f32, | ||
| // CHECK-SAME: %[[VAL_1:.*]]: f32) { | ||
| // CHECK: %[[VAL_2:.*]] = emitc.call_opaque "pow"(%[[VAL_0]], %[[VAL_1]]) : (f32, f32) -> f32 | ||
| // CHECK: return | ||
| // CHECK: } | ||
| func.func @powf_to_call_opaque(%arg0: f32, %arg1: f32) { | ||
| %1 = math.powf %arg0, %arg1 : f32 | ||
| return | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| func.func @test(%arg0 : tensor<4xf32>) -> tensor<4xf32> { | ||
| // expected-error @+2 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}} | ||
| // expected-error @+1 {{non-float types are not supported}} | ||
| %0 = math.absf %arg0 : tensor<4xf32> | ||
|
||
| return %0 : tensor<4xf32> | ||
| } | ||
|
||
Uh oh!
There was an error while loading. Please reload this page.