Skip to content

Commit 1e43eb1

Browse files
committed
[mlir][math] Add FP software implementation lowering pass: math-to-apfloat
1 parent f6971bf commit 1e43eb1

File tree

9 files changed

+175
-53
lines changed

9 files changed

+175
-53
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- MathToAPFloat.h - Math to APFloat impl conversion ---*- C++ ------*-===//
2+
//
3+
// Part of the APFloat Project, under the Apache License v2.0 with APFloat
4+
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH APFloat-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
10+
#define MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class Pass;
16+
17+
#define GEN_PASS_DECL_MATHTOAPFLOATCONVERSIONPASS
18+
#include "mlir/Conversion/Passes.h.inc"
19+
} // namespace mlir
20+
21+
#endif // MLIR_CONVERSION_MATHTOAPFLOAT_MATHTOAPFLOAT_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
4545
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
4646
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
47+
#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
4748
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
4849
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
4950
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,21 @@ def ConvertMathToLibmPass : Pass<"convert-math-to-libm", "ModuleOp"> {
775775
];
776776
}
777777

778+
//===----------------------------------------------------------------------===//
779+
// MathToAPFloat
780+
//===----------------------------------------------------------------------===//
781+
782+
def MathToAPFloatConversionPass
783+
: Pass<"convert-math-to-apfloat", "ModuleOp"> {
784+
let summary = "Convert Math ops to APFloat runtime library calls";
785+
let description = [{
786+
This pass converts supported Math ops to APFloat-based runtime library
787+
calls (APFloatWrappers.cpp). APFloat is a software implementation of
788+
floating-point mathmetic operations.
789+
}];
790+
let dependentDialects = ["math::MathDialect", "func::FuncDialect"];
791+
}
792+
778793
//===----------------------------------------------------------------------===//
779794
// MathToLLVM
780795
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Func/Utils/Utils.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,22 @@ FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name,
6767
FunctionType funcT,
6868
SymbolTableCollection *symbolTables = nullptr);
6969

70+
/// Create a FuncOp decl and insert it into `symTable` operation. If
71+
/// `symbolTables` is provided, then the decl will be inserted into the
72+
/// SymbolTableCollection.
73+
FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
74+
FunctionType funcT, bool setPrivate,
75+
SymbolTableCollection *symbolTables = nullptr);
76+
77+
/// Helper function to look up or create the symbol for a runtime library
78+
/// function with the given parameter types. Returns an int64_t, unless a
79+
/// different result type is specified.
80+
FailureOr<FuncOp>
81+
lookupOrCreateFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
82+
TypeRange paramTypes,
83+
SymbolTableCollection *symbolTables = nullptr,
84+
Type resultType = {});
85+
7086
} // namespace func
7187
} // namespace mlir
7288

mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp

Lines changed: 14 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -25,47 +25,6 @@ namespace mlir {
2525
using namespace mlir;
2626
using namespace mlir::func;
2727

28-
static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
29-
StringRef name, FunctionType funcT, bool setPrivate,
30-
SymbolTableCollection *symbolTables = nullptr) {
31-
OpBuilder::InsertionGuard g(b);
32-
assert(!symTable->getRegion(0).empty() && "expected non-empty region");
33-
b.setInsertionPointToStart(&symTable->getRegion(0).front());
34-
FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
35-
if (setPrivate)
36-
funcOp.setPrivate();
37-
if (symbolTables) {
38-
SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
39-
symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
40-
}
41-
return funcOp;
42-
}
43-
44-
/// Helper function to look up or create the symbol for a runtime library
45-
/// function with the given parameter types. Returns an int64_t, unless a
46-
/// different result type is specified.
47-
static FailureOr<FuncOp>
48-
lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
49-
StringRef name, TypeRange paramTypes,
50-
SymbolTableCollection *symbolTables = nullptr,
51-
Type resultType = {}) {
52-
if (!resultType)
53-
resultType = IntegerType::get(symTable->getContext(), 64);
54-
std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
55-
auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
56-
FailureOr<FuncOp> func =
57-
lookupFnDecl(symTable, funcName, funcT, symbolTables);
58-
// Failed due to type mismatch.
59-
if (failed(func))
60-
return func;
61-
// Successfully matched existing decl.
62-
if (*func)
63-
return *func;
64-
65-
return createFnDecl(b, symTable, funcName, funcT,
66-
/*setPrivate=*/true, symbolTables);
67-
}
68-
6928
/// Helper function to look up or create the symbol for a runtime library
7029
/// function for a binary arithmetic operation.
7130
///
@@ -81,8 +40,9 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
8140
SymbolTableCollection *symbolTables = nullptr) {
8241
auto i32Type = IntegerType::get(symTable->getContext(), 32);
8342
auto i64Type = IntegerType::get(symTable->getContext(), 64);
84-
return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type},
85-
symbolTables);
43+
std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
44+
return lookupOrCreateFn(b, symTable, funcName, {i32Type, i64Type, i64Type},
45+
symbolTables);
8646
}
8747

8848
static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
@@ -231,8 +191,9 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
231191
// Get APFloat function from runtime library.
232192
auto i32Type = IntegerType::get(symTable->getContext(), 32);
233193
auto i64Type = IntegerType::get(symTable->getContext(), 64);
234-
FailureOr<FuncOp> fn = lookupOrCreateApFloatFn(
235-
rewriter, symTable, "convert", {i32Type, i32Type, i64Type});
194+
FailureOr<FuncOp> fn =
195+
lookupOrCreateFn(rewriter, symTable, "_mlir_apfloat_convert",
196+
{i32Type, i32Type, i64Type});
236197
if (failed(fn))
237198
return fn;
238199

@@ -289,8 +250,8 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
289250
auto i32Type = IntegerType::get(symTable->getContext(), 32);
290251
auto i64Type = IntegerType::get(symTable->getContext(), 64);
291252
FailureOr<FuncOp> fn =
292-
lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int",
293-
{i32Type, i32Type, i1Type, i64Type});
253+
lookupOrCreateFn(rewriter, symTable, "_mlir_apfloat_convert_to_int",
254+
{i32Type, i32Type, i1Type, i64Type});
294255
if (failed(fn))
295256
return fn;
296257

@@ -351,8 +312,8 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
351312
auto i32Type = IntegerType::get(symTable->getContext(), 32);
352313
auto i64Type = IntegerType::get(symTable->getContext(), 64);
353314
FailureOr<FuncOp> fn =
354-
lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
355-
{i32Type, i32Type, i1Type, i64Type});
315+
lookupOrCreateFn(rewriter, symTable, "_mlir_apfloat_convert_from_int",
316+
{i32Type, i32Type, i1Type, i64Type});
356317
if (failed(fn))
357318
return fn;
358319

@@ -421,8 +382,8 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
421382
auto i32Type = IntegerType::get(symTable->getContext(), 32);
422383
auto i64Type = IntegerType::get(symTable->getContext(), 64);
423384
FailureOr<FuncOp> fn =
424-
lookupOrCreateApFloatFn(rewriter, symTable, "compare",
425-
{i32Type, i64Type, i64Type}, nullptr, i8Type);
385+
lookupOrCreateFn(rewriter, symTable, "_mlir_apfloat_compare",
386+
{i32Type, i64Type, i64Type}, nullptr, i8Type);
426387
if (failed(fn))
427388
return fn;
428389

@@ -569,8 +530,8 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
569530
// Get APFloat function from runtime library.
570531
auto i32Type = IntegerType::get(symTable->getContext(), 32);
571532
auto i64Type = IntegerType::get(symTable->getContext(), 64);
572-
FailureOr<FuncOp> fn =
573-
lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type});
533+
FailureOr<FuncOp> fn = lookupOrCreateFn(
534+
rewriter, symTable, "_mlir_apfloat_neg", {i32Type, i64Type});
574535
if (failed(fn))
575536
return fn;
576537

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ add_subdirectory(IndexToLLVM)
3535
add_subdirectory(IndexToSPIRV)
3636
add_subdirectory(LinalgToStandard)
3737
add_subdirectory(LLVMCommon)
38+
add_subdirectory(MathToAPFloat)
3839
add_subdirectory(MathToEmitC)
3940
add_subdirectory(MathToFuncs)
4041
add_subdirectory(MathToLibm)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
add_mlir_conversion_library(MLIRMathToAPFloat
2+
MathToAPFloat.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLLVM
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIRMathDialect
15+
MLIRFuncDialect
16+
MLIRFuncUtils
17+
)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
//===- MathToAPFloat.cpp - Mathmetic to APFloat Conversion ----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/MathToAPFloat/MathToAPFloat.h"
10+
11+
#include "mlir/Dialect/Func/IR/FuncOps.h"
12+
#include "mlir/Dialect/Func/Utils/Utils.h"
13+
#include "mlir/Dialect/Math/IR/Math.h"
14+
#include "mlir/Dialect/Math/Transforms/Passes.h"
15+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
16+
#include "mlir/IR/PatternMatch.h"
17+
#include "mlir/IR/Verifier.h"
18+
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
19+
20+
namespace mlir {
21+
#define GEN_PASS_DEF_MATHTOAPFLOATCONVERSIONPASS
22+
#include "mlir/Conversion/Passes.h.inc"
23+
} // namespace mlir
24+
25+
using namespace mlir;
26+
using namespace mlir::func;
27+
28+
namespace {
29+
struct MathToAPFloatConversionPass final
30+
: impl::MathToAPFloatConversionPassBase<MathToAPFloatConversionPass> {
31+
using Base::Base;
32+
33+
void runOnOperation() override;
34+
};
35+
36+
void MathToAPFloatConversionPass::runOnOperation() {
37+
MLIRContext *context = &getContext();
38+
RewritePatternSet patterns(context);
39+
LogicalResult result = success();
40+
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
41+
if (diag.getSeverity() == DiagnosticSeverity::Error) {
42+
result = failure();
43+
}
44+
// NB: if you don't return failure, no other diag handlers will fire (see
45+
// mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
46+
return failure();
47+
});
48+
walkAndApplyPatterns(getOperation(), std::move(patterns));
49+
if (failed(result))
50+
return signalPassFailure();
51+
}
52+
} // namespace

mlir/lib/Dialect/Func/Utils/Utils.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,3 +279,41 @@ func::lookupFnDecl(SymbolOpInterface symTable, StringRef name,
279279
}
280280
return func;
281281
}
282+
283+
func::FuncOp func::createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
284+
StringRef name, FunctionType funcT,
285+
bool setPrivate,
286+
SymbolTableCollection *symbolTables) {
287+
OpBuilder::InsertionGuard g(b);
288+
assert(!symTable->getRegion(0).empty() && "expected non-empty region");
289+
b.setInsertionPointToStart(&symTable->getRegion(0).front());
290+
func::FuncOp funcOp =
291+
func::FuncOp::create(b, symTable->getLoc(), name, funcT);
292+
if (setPrivate)
293+
funcOp.setPrivate();
294+
if (symbolTables) {
295+
SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
296+
symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
297+
}
298+
return funcOp;
299+
}
300+
301+
FailureOr<func::FuncOp>
302+
func::lookupOrCreateFn(OpBuilder &b, SymbolOpInterface symTable,
303+
StringRef funcName, TypeRange paramTypes,
304+
SymbolTableCollection *symbolTables, Type resultType) {
305+
if (!resultType)
306+
resultType = IntegerType::get(symTable->getContext(), 64);
307+
auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
308+
FailureOr<func::FuncOp> func =
309+
lookupFnDecl(symTable, funcName, funcT, symbolTables);
310+
// Failed due to type mismatch.
311+
if (failed(func))
312+
return func;
313+
// Successfully matched existing decl.
314+
if (*func)
315+
return *func;
316+
317+
return createFnDecl(b, symTable, funcName, funcT,
318+
/*setPrivate=*/true, symbolTables);
319+
}

0 commit comments

Comments
 (0)