Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===- MathToEmitC.h - Math to EmitC Pass -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H

namespace mlir {
class RewritePatternSet;

void populateConvertMathToEmitCPatterns(RewritePatternSet &patterns);

} // namespace mlir

#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
21 changes: 21 additions & 0 deletions mlir/include/mlir/Conversion/MathToEmitC/MathToEmitCPass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===- MathToEmitCPass.h - Math to EmitC Pass -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H
#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H

#include <memory>

namespace mlir {
class Pass;

#define GEN_PASS_DECL_CONVERTMATHTOEMITC
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

newline

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right

1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
Expand Down
17 changes: 17 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,23 @@ def ConvertMathToSPIRV : Pass<"convert-math-to-spirv"> {
let dependentDialects = ["spirv::SPIRVDialect"];
}

//===----------------------------------------------------------------------===//
// MathToEmitC
//===----------------------------------------------------------------------===//

def ConvertMathToEmitC : Pass<"convert-math-to-emitc", "ModuleOp"> {
let summary = "Convert some Math operations to EmitC Call_opaque";
let description = [{
This pass converts supported Math ops to call_opaque calls to compiler generated
functions implementing these operations in software.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Maybe something like:
This pass converts supported Math ops to opaque_call ops targeting libc/libm functions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed to what you suggested.
We might add new targets as suggested in this thread

Unlike convert-math-to-funcs pass, this pass uses call_opaque,
therefore enables us to overload the same funtion with different argument types
}];
let dependentDialects = ["emitc::EmitCDialect",
"math::MathDialect"
];
}

//===----------------------------------------------------------------------===//
// MathToFuncs
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ add_subdirectory(IndexToLLVM)
add_subdirectory(IndexToSPIRV)
add_subdirectory(LinalgToStandard)
add_subdirectory(LLVMCommon)
add_subdirectory(MathToEmitC)
add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
Expand Down
20 changes: 20 additions & 0 deletions mlir/lib/Conversion/MathToEmitC/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
add_mlir_conversion_library(MLIRMathToEmitC
MathToEmitC.cpp
MathToEmitCPass.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToEmitC

DEPENDS
MLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRLLVMCommonConversion
MLIREmitCDialect
MLIRMathDialect
MLIRPass
MLIRTransforms
)
65 changes: 65 additions & 0 deletions mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
//===- MathToEmitC.cpp - Math to EmitC Pass Implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {
template <typename OpType>
class LowerToEmitCCallOpaque : public mlir::OpRewritePattern<OpType> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop all the mlir:: namespaces.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, Deleted

std::string calleeStr;

public:
LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr)
: OpRewritePattern<OpType>(context), calleeStr(std::move(calleeStr)) {}

LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rewriter) const override;
};

template <typename OpType>
LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
OpType op, PatternRewriter &rewriter) const {
auto actualOp = mlir::cast<OpType>(op);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this cast?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cast is redundant here, I removed it

if (!llvm::all_of(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use something like

!llvm::all_of(actualOp->getOperandTypes(), llvm::IsaPred<Float32Type, Float64Type>)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did, thanks!

actualOp->getOperands(),
[](Value operand) { return isa<FloatType>(operand.getType()); }) ||
!llvm::all_of(actualOp->getResultTypes(),
[](mlir::Type type) { return isa<FloatType>(type); })) {
op.emitError("non-float types are not supported");
return mlir::failure();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return op.notifyMatchFailure(op.getLoc(), "...")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, changed the lit test accordingly

}
mlir::StringAttr callee = rewriter.getStringAttr(calleeStr);
rewriter.replaceOpWithNewOp<mlir::emitc::CallOpaqueOp>(
actualOp, actualOp.getType(), callee, actualOp->getOperands());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using calleeStr directly should work I think.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right

return mlir::success();
}

} // namespace

// Populates patterns to replace `math` operations with `emitc.call_opaque`,
// using function names consistent with those in <math.h>.
void mlir::populateConvertMathToEmitCPatterns(RewritePatternSet &patterns) {
auto *context = patterns.getContext();
patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor");
patterns.insert<LowerToEmitCCallOpaque<math::RoundEvenOp>>(context, "rint");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may be unsafe, not sure. rint uses the current rounding mode, RoundToEven ignores the rounding mode.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right- for now I support only round

patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "exp");
patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cos");
patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sin");
patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acos");
patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asin");
patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2");
patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil");
patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs");
patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "pow");
}
58 changes: 58 additions & 0 deletions mlir/lib/Conversion/MathToEmitC/MathToEmitCPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
//===- MathToEmitCPass.cpp - Math to EmitC Pass -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert the Math dialect to the EmitC dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOEMITC
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
namespace {

// Replaces Math operations with `emitc.call_opaque` operations.
struct ConvertMathToEmitCPass
: public impl::ConvertMathToEmitCBase<ConvertMathToEmitCPass> {
public:
void runOnOperation() final;
};

} // end anonymous namespace

void ConvertMathToEmitCPass::runOnOperation() {
auto moduleOp = getOperation();
// Insert #include <math.h> at the beginning of the module
OpBuilder builder(moduleOp.getBodyRegion());
builder.setInsertionPointToStart(&moduleOp.getBodyRegion().front());
builder.create<emitc::IncludeOp>(moduleOp.getLoc(),
builder.getStringAttr("math.h"));

ConversionTarget target(getContext());
target.addLegalOp<emitc::CallOpaqueOp>();

target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundEvenOp,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundEvenOp,
target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundOp,

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

math::CosOp, math::SinOp, math::Atan2Op, math::CeilOp,
math::AcosOp, math::AsinOp, math::AbsFOp, math::PowFOp,
math::FPowIOp, math::IPowIOp>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the integer variants here too, please.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right


RewritePatternSet patterns(&getContext());
populateConvertMathToEmitCPatterns(patterns);

if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
signalPassFailure();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

newline

133 changes: 133 additions & 0 deletions mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// RUN: mlir-opt --split-input-file -convert-math-to-emitc -verify-diagnostics %s | FileCheck %s
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Split this into math-to-emitc.mlir for positive tests and math-to-emitc-failed.mlir for negative tests. You can then drop --split-input-file, verify-diagnostics and the // ----- separators from the positve tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added math-to-emitc-failed.mlir which tests invalid types.


// CHECK-LABEL: emitc.include "math.h"

// CHECK-LABEL: func.func @absf_to_call_opaque(
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "fabs"(%[[VAL_0]]) : (f32) -> f32
// CHECK: return
// CHECK: }
func.func @absf_to_call_opaque(%arg0: f32) {
%1 = math.absf %arg0 : f32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add the f64 variants for the tests please

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I decided to support only float32 until I decide how to proceed with this thread. I would appreciate you opinion about it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a commit that support both f64 and f32, and updated the lit tests

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make the tests exhaustive, with something like this:

func.func @floor(%arg0: f32, %arg1: f64) {
    // C: emitc.call_opaque "floorf" (%arg0)
    // C-NEXT: emitc.call_opaque "floor" (%arg1)
    // CPP: emitc.call_opaque "std::floor" (%arg0)
    // CPP-NEXT: emitc.call_opaque "std::floor" (%arg1)
    %0 = math.floor %arg0 : f32
    %1 = math.floor %arg1 : f64
    return
}

Addiotionally a test for math.round is missing currently.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. I changed the test accordingly and added one for math.round.

return
}

// -----

// CHECK-LABEL: func.func @floor_to_call_opaque(
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "floor"(%[[VAL_0]]) : (f32) -> f32
// CHECK: return
// CHECK: }
func.func @floor_to_call_opaque(%arg0: f32) {
%1 = math.floor %arg0 : f32
return
}

// -----

// CHECK-LABEL: func.func @sin_to_call_opaque(
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "sin"(%[[VAL_0]]) : (f32) -> f32
// CHECK: return
// CHECK: }
func.func @sin_to_call_opaque(%arg0: f32) {
%1 = math.sin %arg0 : f32
return
}

// -----

// CHECK-LABEL: func.func @cos_to_call_opaque(
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "cos"(%[[VAL_0]]) : (f32) -> f32
// CHECK: return
// CHECK: }
func.func @cos_to_call_opaque(%arg0: f32) {
%1 = math.cos %arg0 : f32
return
}


// -----

// CHECK-LABEL: func.func @asin_to_call_opaque(
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "asin"(%[[VAL_0]]) : (f32) -> f32
// CHECK: return
// CHECK: }
func.func @asin_to_call_opaque(%arg0: f32) {
%1 = math.asin %arg0 : f32
return
}

// -----

// CHECK-LABEL: func.func @acos_to_call_opaque(
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "acos"(%[[VAL_0]]) : (f32) -> f32
// CHECK: return
// CHECK: }
func.func @acos_to_call_opaque(%arg0: f32) {
%1 = math.acos %arg0 : f32
return
}

// -----

// CHECK-LABEL: func.func @atan2_to_call_opaque(
// CHECK-SAME: %[[VAL_0:.*]]: f32,
// CHECK-SAME: %[[VAL_1:.*]]: f32) {
// CHECK: %[[VAL_2:.*]] = emitc.call_opaque "atan2"(%[[VAL_0]], %[[VAL_1]]) : (f32, f32) -> f32
// CHECK: return
// CHECK: }
func.func @atan2_to_call_opaque(%arg0: f32, %arg1: f32) {
%1 = math.atan2 %arg0, %arg1 : f32
return
}

// -----

// CHECK-LABEL: func.func @ceil_to_call_opaque(
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "ceil"(%[[VAL_0]]) : (f32) -> f32
// CHECK: return
// CHECK: }
func.func @ceil_to_call_opaque(%arg0: f32) {
%1 = math.ceil %arg0 : f32
return
}

// -----

// CHECK-LABEL: func.func @exp_to_call_opaque(
// CHECK-SAME: %[[VAL_0:.*]]: f32) {
// CHECK: %[[VAL_1:.*]] = emitc.call_opaque "exp"(%[[VAL_0]]) : (f32) -> f32
// CHECK: return
// CHECK: }
func.func @exp_to_call_opaque(%arg0: f32) {
%1 = math.exp %arg0 : f32
return
}

// -----

// CHECK-LABEL: func.func @powf_to_call_opaque(
// CHECK-SAME: %[[VAL_0:.*]]: f32,
// CHECK-SAME: %[[VAL_1:.*]]: f32) {
// CHECK: %[[VAL_2:.*]] = emitc.call_opaque "pow"(%[[VAL_0]], %[[VAL_1]]) : (f32, f32) -> f32
// CHECK: return
// CHECK: }
func.func @powf_to_call_opaque(%arg0: f32, %arg1: f32) {
%1 = math.powf %arg0, %arg1 : f32
return
}

// -----

func.func @test(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
// expected-error @+2 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}}
// expected-error @+1 {{non-float types are not supported}}
%0 = math.absf %arg0 : tensor<4xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add tests for unsupported bit-widths, like f16, f128

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added those to math-to-emitc-failed.mlir

return %0 : tensor<4xf32>
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

newline

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right

Loading
Loading