Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions mlir/include/mlir/Conversion/MathToAPFloat/MathToAPFloat.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===- MathToAPFloat.h - Math to APFloat impl conversion ---*- C++ ------*-===//
//
// Part of the APFloat Project, under the Apache License v2.0 with APFloat
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH APFloat-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
#define MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H

#include <memory>

namespace mlir {
class Pass;

#define GEN_PASS_DECL_MATHTOAPFLOATCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

#endif // MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
Expand Down
15 changes: 15 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,21 @@ def ConvertMathToLibmPass : Pass<"convert-math-to-libm", "ModuleOp"> {
];
}

//===----------------------------------------------------------------------===//
// MathToAPFloat
//===----------------------------------------------------------------------===//

def MathToAPFloatConversionPass
: Pass<"convert-math-to-apfloat", "ModuleOp"> {
let summary = "Convert Math ops to APFloat runtime library calls";
let description = [{
This pass converts supported Math ops to APFloat-based runtime library
calls (APFloatWrappers.cpp). APFloat is a software implementation of
floating-point mathmetic operations.
}];
let dependentDialects = ["math::MathDialect", "func::FuncDialect"];
}

//===----------------------------------------------------------------------===//
// MathToLLVM
//===----------------------------------------------------------------------===//
Expand Down
16 changes: 16 additions & 0 deletions mlir/include/mlir/Dialect/Func/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,22 @@ FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name,
FunctionType funcT,
SymbolTableCollection *symbolTables = nullptr);

/// Create a FuncOp decl and insert it into `symTable` operation. If
/// `symbolTables` is provided, then the decl will be inserted into the
/// SymbolTableCollection.
FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
FunctionType funcT, bool setPrivate,
SymbolTableCollection *symbolTables = nullptr);

/// Helper function to look up or create the symbol for a runtime library
/// function with the given parameter types. Returns an int64_t, unless a
/// different result type is specified.
FailureOr<FuncOp>
lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
TypeRange paramTypes,
SymbolTableCollection *symbolTables = nullptr,
Type resultType = {});

} // namespace func
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
#include "Utils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
Expand All @@ -25,47 +26,6 @@ namespace mlir {
using namespace mlir;
using namespace mlir::func;

static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
StringRef name, FunctionType funcT, bool setPrivate,
SymbolTableCollection *symbolTables = nullptr) {
OpBuilder::InsertionGuard g(b);
assert(!symTable->getRegion(0).empty() && "expected non-empty region");
b.setInsertionPointToStart(&symTable->getRegion(0).front());
FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
if (setPrivate)
funcOp.setPrivate();
if (symbolTables) {
SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
}
return funcOp;
}

/// Helper function to look up or create the symbol for a runtime library
/// function with the given parameter types. Returns an int64_t, unless a
/// different result type is specified.
static FailureOr<FuncOp>
lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
StringRef name, TypeRange paramTypes,
SymbolTableCollection *symbolTables = nullptr,
Type resultType = {}) {
if (!resultType)
resultType = IntegerType::get(symTable->getContext(), 64);
std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
FailureOr<FuncOp> func =
lookupFnDecl(symTable, funcName, funcT, symbolTables);
// Failed due to type mismatch.
if (failed(func))
return func;
// Successfully matched existing decl.
if (*func)
return *func;

return createFnDecl(b, symTable, funcName, funcT,
/*setPrivate=*/true, symbolTables);
}

/// Helper function to look up or create the symbol for a runtime library
/// function for a binary arithmetic operation.
///
Expand All @@ -81,14 +41,9 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
SymbolTableCollection *symbolTables = nullptr) {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type},
symbolTables);
}

static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
return arith::ConstantOp::create(b, loc, b.getI32Type(),
b.getIntegerAttr(b.getI32Type(), sem));
std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
return lookupOrCreateFnDecl(b, symTable, funcName,
{i32Type, i64Type, i64Type}, symbolTables);
}

/// Given two operands of vector type and vector result type (with the same
Expand Down Expand Up @@ -197,7 +152,7 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
arith::BitcastOp::create(rewriter, loc, intWType, rhs));

// Call APFloat function.
Value semValue = getSemanticsValue(rewriter, loc, floatTy);
Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
auto resultOp = func::CallOp::create(rewriter, loc,
TypeRange(rewriter.getI64Type()),
Expand Down Expand Up @@ -231,8 +186,9 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn = lookupOrCreateApFloatFn(
rewriter, symTable, "convert", {i32Type, i32Type, i64Type});
FailureOr<FuncOp> fn =
lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert",
{i32Type, i32Type, i64Type});
if (failed(fn))
return fn;

Expand All @@ -250,9 +206,10 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));

// Call APFloat function.
Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
auto outFloatTy = cast<FloatType>(resultType);
Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
Value outSemValue =
getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
auto resultOp = func::CallOp::create(rewriter, loc,
TypeRange(rewriter.getI64Type()),
Expand Down Expand Up @@ -289,8 +246,8 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn =
lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int",
{i32Type, i32Type, i1Type, i64Type});
lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert_to_int",
{i32Type, i32Type, i1Type, i64Type});
if (failed(fn))
return fn;

Expand All @@ -308,7 +265,7 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));

// Call APFloat function.
Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
auto outIntTy = cast<IntegerType>(resultType);
Value outWidthValue = arith::ConstantOp::create(
rewriter, loc, i32Type,
Expand Down Expand Up @@ -350,9 +307,9 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
auto i1Type = IntegerType::get(symTable->getContext(), 1);
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn =
lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
{i32Type, i32Type, i1Type, i64Type});
FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
rewriter, symTable, "_mlir_apfloat_convert_from_int",
{i32Type, i32Type, i1Type, i64Type});
if (failed(fn))
return fn;

Expand All @@ -377,7 +334,8 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {

// Call APFloat function.
auto outFloatTy = cast<FloatType>(resultType);
Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
Value outSemValue =
getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
Value inWidthValue = arith::ConstantOp::create(
rewriter, loc, i32Type,
rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
Expand Down Expand Up @@ -421,8 +379,8 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn =
lookupOrCreateApFloatFn(rewriter, symTable, "compare",
{i32Type, i64Type, i64Type}, nullptr, i8Type);
lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_compare",
{i32Type, i64Type, i64Type}, nullptr, i8Type);
if (failed(fn))
return fn;

Expand All @@ -443,7 +401,7 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
arith::BitcastOp::create(rewriter, loc, intWType, rhs));

// Call APFloat function.
Value semValue = getSemanticsValue(rewriter, loc, floatTy);
Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
Value comparisonResult =
func::CallOp::create(rewriter, loc, TypeRange(i8Type),
Expand Down Expand Up @@ -569,8 +527,8 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn =
lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type});
FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
rewriter, symTable, "_mlir_apfloat_neg", {i32Type, i64Type});
if (failed(fn))
return fn;

Expand All @@ -588,7 +546,7 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
arith::BitcastOp::create(rewriter, loc, intWType, operand1));

// Call APFloat function.
Value semValue = getSemanticsValue(rewriter, loc, floatTy);
Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, operandBits};
Value negatedBits =
func::CallOp::create(rewriter, loc, TypeRange(i64Type),
Expand Down
49 changes: 49 additions & 0 deletions mlir/lib/Conversion/ArithAndMathToAPFloat/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
add_mlir_library(ArithAndMathToAPFloatUtils
Utils.cpp
PARTIAL_SOURCES_INTENDED

LINK_LIBS PUBLIC
MLIRArithDialect
)

add_mlir_conversion_library(MLIRArithToAPFloat
ArithToAPFloat.cpp
PARTIAL_SOURCES_INTENDED

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM

DEPENDS
MLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
ArithAndMathToAPFloatUtils
MLIRArithDialect
MLIRArithTransforms
MLIRFuncDialect
MLIRFuncUtils
MLIRVectorDialect
)

add_mlir_conversion_library(MLIRMathToAPFloat
MathToAPFloat.cpp
PARTIAL_SOURCES_INTENDED

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLLVM

DEPENDS
MLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
ArithAndMathToAPFloatUtils
MLIRMathDialect
MLIRFuncDialect
MLIRFuncUtils
)
Loading