Skip to content

Commit 7533d56

Browse files
committed
[mlir][math] Add FP software implementation lowering pass: math-to-apfloat
1 parent f6971bf commit 7533d56

File tree

15 files changed

+442
-93
lines changed

15 files changed

+442
-93
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- MathToAPFloat.h - Math to APFloat impl conversion ---*- C++ ------*-===//
2+
//
3+
// Part of the APFloat Project, under the Apache License v2.0 with APFloat
4+
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH APFloat-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
10+
#define MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class Pass;
16+
17+
#define GEN_PASS_DECL_MATHTOAPFLOATCONVERSIONPASS
18+
#include "mlir/Conversion/Passes.h.inc"
19+
} // namespace mlir
20+
21+
#endif // MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
4545
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
4646
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
47+
#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
4748
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
4849
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
4950
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,21 @@ def ConvertMathToLibmPass : Pass<"convert-math-to-libm", "ModuleOp"> {
775775
];
776776
}
777777

778+
//===----------------------------------------------------------------------===//
779+
// MathToAPFloat
780+
//===----------------------------------------------------------------------===//
781+
782+
def MathToAPFloatConversionPass
783+
: Pass<"convert-math-to-apfloat", "ModuleOp"> {
784+
let summary = "Convert Math ops to APFloat runtime library calls";
785+
let description = [{
786+
This pass converts supported Math ops to APFloat-based runtime library
787+
calls (APFloatWrappers.cpp). APFloat is a software implementation of
788+
floating-point mathmetic operations.
789+
}];
790+
let dependentDialects = ["math::MathDialect", "func::FuncDialect"];
791+
}
792+
778793
//===----------------------------------------------------------------------===//
779794
// MathToLLVM
780795
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Func/Utils/Utils.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,22 @@ FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name,
6767
FunctionType funcT,
6868
SymbolTableCollection *symbolTables = nullptr);
6969

70+
/// Create a FuncOp decl and insert it into `symTable` operation. If
71+
/// `symbolTables` is provided, then the decl will be inserted into the
72+
/// SymbolTableCollection.
73+
FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
74+
FunctionType funcT, bool setPrivate,
75+
SymbolTableCollection *symbolTables = nullptr);
76+
77+
/// Helper function to look up or create the symbol for a runtime library
78+
/// function with the given parameter types. Returns an int64_t, unless a
79+
/// different result type is specified.
80+
FailureOr<FuncOp>
81+
lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
82+
TypeRange paramTypes,
83+
SymbolTableCollection *symbolTables = nullptr,
84+
Type resultType = {});
85+
7086
} // namespace func
7187
} // namespace mlir
7288

mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp renamed to mlir/lib/Conversion/ArithAndMathToAPFloat/ArithToAPFloat.cpp

Lines changed: 25 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
10+
#include "Utils.h"
1011

1112
#include "mlir/Dialect/Arith/IR/Arith.h"
1213
#include "mlir/Dialect/Arith/Transforms/Passes.h"
@@ -25,47 +26,6 @@ namespace mlir {
2526
using namespace mlir;
2627
using namespace mlir::func;
2728

28-
static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
29-
StringRef name, FunctionType funcT, bool setPrivate,
30-
SymbolTableCollection *symbolTables = nullptr) {
31-
OpBuilder::InsertionGuard g(b);
32-
assert(!symTable->getRegion(0).empty() && "expected non-empty region");
33-
b.setInsertionPointToStart(&symTable->getRegion(0).front());
34-
FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
35-
if (setPrivate)
36-
funcOp.setPrivate();
37-
if (symbolTables) {
38-
SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
39-
symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
40-
}
41-
return funcOp;
42-
}
43-
44-
/// Helper function to look up or create the symbol for a runtime library
45-
/// function with the given parameter types. Returns an int64_t, unless a
46-
/// different result type is specified.
47-
static FailureOr<FuncOp>
48-
lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
49-
StringRef name, TypeRange paramTypes,
50-
SymbolTableCollection *symbolTables = nullptr,
51-
Type resultType = {}) {
52-
if (!resultType)
53-
resultType = IntegerType::get(symTable->getContext(), 64);
54-
std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
55-
auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
56-
FailureOr<FuncOp> func =
57-
lookupFnDecl(symTable, funcName, funcT, symbolTables);
58-
// Failed due to type mismatch.
59-
if (failed(func))
60-
return func;
61-
// Successfully matched existing decl.
62-
if (*func)
63-
return *func;
64-
65-
return createFnDecl(b, symTable, funcName, funcT,
66-
/*setPrivate=*/true, symbolTables);
67-
}
68-
6929
/// Helper function to look up or create the symbol for a runtime library
7030
/// function for a binary arithmetic operation.
7131
///
@@ -81,14 +41,9 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
8141
SymbolTableCollection *symbolTables = nullptr) {
8242
auto i32Type = IntegerType::get(symTable->getContext(), 32);
8343
auto i64Type = IntegerType::get(symTable->getContext(), 64);
84-
return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type},
85-
symbolTables);
86-
}
87-
88-
static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
89-
int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
90-
return arith::ConstantOp::create(b, loc, b.getI32Type(),
91-
b.getIntegerAttr(b.getI32Type(), sem));
44+
std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
45+
return lookupOrCreateFnDecl(b, symTable, funcName,
46+
{i32Type, i64Type, i64Type}, symbolTables);
9247
}
9348

9449
/// Given two operands of vector type and vector result type (with the same
@@ -197,7 +152,7 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
197152
arith::BitcastOp::create(rewriter, loc, intWType, rhs));
198153

199154
// Call APFloat function.
200-
Value semValue = getSemanticsValue(rewriter, loc, floatTy);
155+
Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
201156
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
202157
auto resultOp = func::CallOp::create(rewriter, loc,
203158
TypeRange(rewriter.getI64Type()),
@@ -231,8 +186,9 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
231186
// Get APFloat function from runtime library.
232187
auto i32Type = IntegerType::get(symTable->getContext(), 32);
233188
auto i64Type = IntegerType::get(symTable->getContext(), 64);
234-
FailureOr<FuncOp> fn = lookupOrCreateApFloatFn(
235-
rewriter, symTable, "convert", {i32Type, i32Type, i64Type});
189+
FailureOr<FuncOp> fn =
190+
lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert",
191+
{i32Type, i32Type, i64Type});
236192
if (failed(fn))
237193
return fn;
238194

@@ -250,9 +206,10 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
250206
arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
251207

252208
// Call APFloat function.
253-
Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
209+
Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
254210
auto outFloatTy = cast<FloatType>(resultType);
255-
Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
211+
Value outSemValue =
212+
getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
256213
std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
257214
auto resultOp = func::CallOp::create(rewriter, loc,
258215
TypeRange(rewriter.getI64Type()),
@@ -289,8 +246,8 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
289246
auto i32Type = IntegerType::get(symTable->getContext(), 32);
290247
auto i64Type = IntegerType::get(symTable->getContext(), 64);
291248
FailureOr<FuncOp> fn =
292-
lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int",
293-
{i32Type, i32Type, i1Type, i64Type});
249+
lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert_to_int",
250+
{i32Type, i32Type, i1Type, i64Type});
294251
if (failed(fn))
295252
return fn;
296253

@@ -308,7 +265,7 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
308265
arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
309266

310267
// Call APFloat function.
311-
Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
268+
Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
312269
auto outIntTy = cast<IntegerType>(resultType);
313270
Value outWidthValue = arith::ConstantOp::create(
314271
rewriter, loc, i32Type,
@@ -350,9 +307,9 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
350307
auto i1Type = IntegerType::get(symTable->getContext(), 1);
351308
auto i32Type = IntegerType::get(symTable->getContext(), 32);
352309
auto i64Type = IntegerType::get(symTable->getContext(), 64);
353-
FailureOr<FuncOp> fn =
354-
lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
355-
{i32Type, i32Type, i1Type, i64Type});
310+
FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
311+
rewriter, symTable, "_mlir_apfloat_convert_from_int",
312+
{i32Type, i32Type, i1Type, i64Type});
356313
if (failed(fn))
357314
return fn;
358315

@@ -377,7 +334,8 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
377334

378335
// Call APFloat function.
379336
auto outFloatTy = cast<FloatType>(resultType);
380-
Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
337+
Value outSemValue =
338+
getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
381339
Value inWidthValue = arith::ConstantOp::create(
382340
rewriter, loc, i32Type,
383341
rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
@@ -421,8 +379,8 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
421379
auto i32Type = IntegerType::get(symTable->getContext(), 32);
422380
auto i64Type = IntegerType::get(symTable->getContext(), 64);
423381
FailureOr<FuncOp> fn =
424-
lookupOrCreateApFloatFn(rewriter, symTable, "compare",
425-
{i32Type, i64Type, i64Type}, nullptr, i8Type);
382+
lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_compare",
383+
{i32Type, i64Type, i64Type}, nullptr, i8Type);
426384
if (failed(fn))
427385
return fn;
428386

@@ -443,7 +401,7 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
443401
arith::BitcastOp::create(rewriter, loc, intWType, rhs));
444402

445403
// Call APFloat function.
446-
Value semValue = getSemanticsValue(rewriter, loc, floatTy);
404+
Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
447405
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
448406
Value comparisonResult =
449407
func::CallOp::create(rewriter, loc, TypeRange(i8Type),
@@ -569,8 +527,8 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
569527
// Get APFloat function from runtime library.
570528
auto i32Type = IntegerType::get(symTable->getContext(), 32);
571529
auto i64Type = IntegerType::get(symTable->getContext(), 64);
572-
FailureOr<FuncOp> fn =
573-
lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type});
530+
FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
531+
rewriter, symTable, "_mlir_apfloat_neg", {i32Type, i64Type});
574532
if (failed(fn))
575533
return fn;
576534

@@ -588,7 +546,7 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
588546
arith::BitcastOp::create(rewriter, loc, intWType, operand1));
589547

590548
// Call APFloat function.
591-
Value semValue = getSemanticsValue(rewriter, loc, floatTy);
549+
Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
592550
SmallVector<Value> params = {semValue, operandBits};
593551
Value negatedBits =
594552
func::CallOp::create(rewriter, loc, TypeRange(i64Type),
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
add_mlir_library(ArithAndMathToAPFloatUtils
2+
Utils.cpp
3+
PARTIAL_SOURCES_INTENDED
4+
5+
LINK_LIBS PUBLIC
6+
MLIRArithDialect
7+
)
8+
9+
add_mlir_conversion_library(MLIRArithToAPFloat
10+
ArithToAPFloat.cpp
11+
PARTIAL_SOURCES_INTENDED
12+
13+
ADDITIONAL_HEADER_DIRS
14+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM
15+
16+
DEPENDS
17+
MLIRConversionPassIncGen
18+
19+
LINK_COMPONENTS
20+
Core
21+
22+
LINK_LIBS PUBLIC
23+
ArithAndMathToAPFloatUtils
24+
MLIRArithDialect
25+
MLIRArithTransforms
26+
MLIRFuncDialect
27+
MLIRFuncUtils
28+
MLIRVectorDialect
29+
)
30+
31+
add_mlir_conversion_library(MLIRMathToAPFloat
32+
MathToAPFloat.cpp
33+
PARTIAL_SOURCES_INTENDED
34+
35+
ADDITIONAL_HEADER_DIRS
36+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLLVM
37+
38+
DEPENDS
39+
MLIRConversionPassIncGen
40+
41+
LINK_COMPONENTS
42+
Core
43+
44+
LINK_LIBS PUBLIC
45+
ArithAndMathToAPFloatUtils
46+
MLIRMathDialect
47+
MLIRFuncDialect
48+
MLIRFuncUtils
49+
)

0 commit comments

Comments
 (0)