-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][EmitC] Add MathToEmitC pass for math function lowering to EmitC #113799
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
d5bd00c
ad9af42
f6c2406
23ada46
8992fd5
52c35a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| //===- MathToEmitC.h - Math to EmitCPatterns -------------------*- C++ -*-===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H | ||
| #define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H | ||
| #include "mlir/Dialect/EmitC/IR/EmitC.h" | ||
|
|
||
| namespace mlir { | ||
| class RewritePatternSet; | ||
|
|
||
marbre marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| void populateConvertMathToEmitCPatterns( | ||
| RewritePatternSet &patterns, | ||
| emitc::MathToEmitCLanguageTarget languageTarget); | ||
| } // namespace mlir | ||
|
|
||
| #endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| //===- MathToEmitCPass.h - Math to EmitC Pass -------------------*- C++ -*-===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H | ||
| #define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H | ||
|
|
||
| #include "mlir/Dialect/EmitC/IR/EmitC.h" | ||
| #include <memory> | ||
| namespace mlir { | ||
| class Pass; | ||
|
|
||
| #define GEN_PASS_DECL_CONVERTMATHTOEMITC | ||
| #include "mlir/Conversion/Passes.h.inc" | ||
| } // namespace mlir | ||
|
|
||
| #endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,4 +62,13 @@ def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque"> { | |
|
|
||
| def EmitC_OpaqueOrTypedAttr : AnyAttrOf<[EmitC_OpaqueAttr, TypedAttrInterface]>; | ||
|
|
||
| def MathToEmitCLanguageTarget : I32EnumAttr<"MathToEmitCLanguageTarget", | ||
| "Specifies the language target for generating callees.", [ | ||
| I32EnumAttrCase<"C", 0, "Use C-style function names">, | ||
| I32EnumAttrCase<"CPP", 1, "Use C++-style function names"> | ||
| ]> { | ||
| let cppNamespace = "::mlir::emitc"; | ||
| } | ||
|
||
|
|
||
|
|
||
| #endif // MLIR_DIALECT_EMITC_IR_EMITCATTRIBUTES | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| add_mlir_conversion_library(MLIRMathToEmitC | ||
| MathToEmitC.cpp | ||
| MathToEmitCPass.cpp | ||
|
|
||
| ADDITIONAL_HEADER_DIRS | ||
| ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToEmitC | ||
|
|
||
| DEPENDS | ||
| MLIRConversionPassIncGen | ||
|
|
||
| LINK_COMPONENTS | ||
| Core | ||
|
|
||
| LINK_LIBS PUBLIC | ||
| MLIREmitCDialect | ||
| MLIRMathDialect | ||
| MLIRPass | ||
| MLIRTransforms | ||
marbre marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ) | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,83 @@ | ||||||
| //===- MathToEmitC.cpp - Math to EmitC Patterns ----------------*- C++ -*-===// | ||||||
marbre marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
| // | ||||||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||||||
| // See https://llvm.org/LICENSE.txt for license information. | ||||||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||||
| // | ||||||
| //===----------------------------------------------------------------------===// | ||||||
|
|
||||||
| #include "mlir/Conversion/MathToEmitC/MathToEmitC.h" | ||||||
|
|
||||||
| #include "mlir/Dialect/EmitC/IR/EmitC.h" | ||||||
| #include "mlir/Dialect/Math/IR/Math.h" | ||||||
| #include "mlir/Transforms/DialectConversion.h" | ||||||
|
|
||||||
| using namespace mlir; | ||||||
|
|
||||||
| namespace { | ||||||
| template <typename OpType> | ||||||
| class LowerToEmitCCallOpaque : public OpRewritePattern<OpType> { | ||||||
| std::string calleeStr; | ||||||
| emitc::MathToEmitCLanguageTarget languageTarget; | ||||||
|
|
||||||
| public: | ||||||
| LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr, | ||||||
| emitc::MathToEmitCLanguageTarget languageTarget) | ||||||
| : OpRewritePattern<OpType>(context), calleeStr(std::move(calleeStr)), | ||||||
| languageTarget(languageTarget) {} | ||||||
|
|
||||||
| LogicalResult matchAndRewrite(OpType op, | ||||||
| PatternRewriter &rewriter) const override; | ||||||
| }; | ||||||
|
|
||||||
| template <typename OpType> | ||||||
| LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite( | ||||||
| OpType op, PatternRewriter &rewriter) const { | ||||||
| if (!llvm::all_of(op->getOperandTypes(), llvm::IsaPred<Float32Type, Float64Type>)|| | ||||||
| !llvm::all_of(op->getResultTypes(),llvm::IsaPred<Float32Type, Float64Type>)) | ||||||
| return rewriter.notifyMatchFailure( | ||||||
| op.getLoc(), "expected all operands and results to be of type f32"); | ||||||
|
||||||
| op.getLoc(), "expected all operands and results to be of type f32"); | |
| op.getLoc(), "expected all operands and results to be of type f32 or f64"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,52 @@ | ||||||
| //===- MathToEmitCPass.cpp - Math to EmitC Pass -----------------*- C++ -*-===// | ||||||
| // | ||||||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||||||
| // See https://llvm.org/LICENSE.txt for license information. | ||||||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||||
| // | ||||||
| //===----------------------------------------------------------------------===// | ||||||
| // | ||||||
| // This file implements a pass to convert the Math dialect to the EmitC dialect. | ||||||
| // | ||||||
| //===----------------------------------------------------------------------===// | ||||||
|
|
||||||
| #include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h" | ||||||
| #include "mlir/Conversion/MathToEmitC/MathToEmitC.h" | ||||||
| #include "mlir/Dialect/EmitC/IR/EmitC.h" | ||||||
| #include "mlir/Dialect/Math/IR/Math.h" | ||||||
| #include "mlir/Pass/Pass.h" | ||||||
| #include "mlir/Transforms/DialectConversion.h" | ||||||
|
|
||||||
| namespace mlir { | ||||||
| #define GEN_PASS_DEF_CONVERTMATHTOEMITC | ||||||
| #include "mlir/Conversion/Passes.h.inc" | ||||||
| } // namespace mlir | ||||||
|
|
||||||
| using namespace mlir; | ||||||
| namespace { | ||||||
|
|
||||||
| // Replaces Math operations with `emitc.call_opaque` operations. | ||||||
| struct ConvertMathToEmitC | ||||||
| : public impl::ConvertMathToEmitCBase<ConvertMathToEmitC> { | ||||||
| using ConvertMathToEmitCBase::ConvertMathToEmitCBase; | ||||||
|
|
||||||
| public: | ||||||
| void runOnOperation() final; | ||||||
| }; | ||||||
|
|
||||||
| } // namespace | ||||||
|
|
||||||
| void ConvertMathToEmitC::runOnOperation() { | ||||||
| ConversionTarget target(getContext()); | ||||||
| target.addLegalOp<emitc::CallOpaqueOp>(); | ||||||
|
|
||||||
| target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundEvenOp, | ||||||
|
||||||
| target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundEvenOp, | |
| target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundOp, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch!
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| // RUN: mlir-opt -split-input-file -convert-math-to-emitc -verify-diagnostics %s | ||
|
|
||
| func.func @unsupported_tensor_type(%arg0 : tensor<4xf32>) -> tensor<4xf32> { | ||
| // expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}} | ||
| %0 = math.absf %arg0 : tensor<4xf32> | ||
| return %0 : tensor<4xf32> | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| func.func @unsupported_f16_type(%arg0 : f16) -> f16 { | ||
| // expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}} | ||
| %0 = math.absf %arg0 : f16 | ||
| return %0 : f16 | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| func.func @unsupported_f128_type(%arg0 : f128) -> f128 { | ||
| // expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}} | ||
| %0 = math.absf %arg0 : f128 | ||
| return %0 : f128 | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| // RUN: mlir-opt -convert-math-to-emitc=language-target=C %s | FileCheck %s --check-prefix=C | ||
| // RUN: mlir-opt -convert-math-to-emitc=language-target=CPP %s | FileCheck %s --check-prefix=CPP | ||
|
|
||
| func.func @absf_to_call_opaque(%arg0: f32) { | ||
| // C: emitc.call_opaque "fabsf" | ||
| // CPP: emitc.call_opaque "std::fabs" | ||
| %1 = math.absf %arg0 : f32 | ||
|
||
| return | ||
| } | ||
| func.func @floor_to_call_opaque(%arg0: f32) { | ||
| // C: emitc.call_opaque "floorf" | ||
| // CPP: emitc.call_opaque "std::floor" | ||
| %1 = math.floor %arg0 : f32 | ||
| return | ||
| } | ||
| func.func @sin_to_call_opaque(%arg0: f32) { | ||
| // C: emitc.call_opaque "sinf" | ||
| // CPP: emitc.call_opaque "std::sin" | ||
| %1 = math.sin %arg0 : f32 | ||
| return | ||
| } | ||
| func.func @cos_to_call_opaque(%arg0: f32) { | ||
| // C: emitc.call_opaque "cosf" | ||
| // CPP: emitc.call_opaque "std::cos" | ||
| %1 = math.cos %arg0 : f32 | ||
| return | ||
| } | ||
| func.func @asin_to_call_opaque(%arg0: f32) { | ||
| // C: emitc.call_opaque "asinf" | ||
| // CPP: emitc.call_opaque "std::asin" | ||
| %1 = math.asin %arg0 : f32 | ||
| return | ||
| } | ||
| func.func @acos_to_call_opaque(%arg0: f64) { | ||
| // C: emitc.call_opaque "acos" | ||
| // CPP: emitc.call_opaque "std::acos" | ||
| %1 = math.acos %arg0 : f64 | ||
| return | ||
| } | ||
| func.func @atan2_to_call_opaque(%arg0: f64, %arg1: f64) { | ||
| // C: emitc.call_opaque "atan2" | ||
| // CPP: emitc.call_opaque "std::atan2" | ||
| %1 = math.atan2 %arg0, %arg1 : f64 | ||
| return | ||
| } | ||
| func.func @ceil_to_call_opaque(%arg0: f64) { | ||
| // C: emitc.call_opaque "ceil" | ||
| // CPP: emitc.call_opaque "std::ceil" | ||
| %1 = math.ceil %arg0 : f64 | ||
| return | ||
| } | ||
| func.func @exp_to_call_opaque(%arg0: f64) { | ||
| // C: emitc.call_opaque "exp" | ||
| // CPP: emitc.call_opaque "std::exp" | ||
| %1 = math.exp %arg0 : f64 | ||
| return | ||
| } | ||
| func.func @powf_to_call_opaque(%arg0: f64, %arg1: f64) { | ||
| // C: emitc.call_opaque "pow" | ||
| // CPP: emitc.call_opaque "std::pow" | ||
| %1 = math.powf %arg0, %arg1 : f64 | ||
| return | ||
| } | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.