Skip to content

Commit c42f471

Browse files
committed
add arith-to-apfloat
1 parent 3680b11 commit c42f471

File tree

13 files changed

+266
-111
lines changed

13 files changed

+266
-111
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- ArithToAPFloat.h - Arith to APFloat impl conversion ---*- C++ ----*-===//
2+
//
3+
// Part of the APFloat Project, under the Apache License v2.0 with APFloat
4+
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH APFloat-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
10+
#define MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
16+
class DialectRegistry;
17+
class RewritePatternSet;
18+
class Pass;
19+
20+
#define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS
21+
#include "mlir/Conversion/Passes.h.inc"
22+
23+
namespace arith {
24+
void populateArithToAPFloatConversionPatterns(RewritePatternSet &patterns);
25+
} // namespace arith
26+
} // namespace mlir
27+
28+
#endif // MLIR_CONVERSION_ARITHTOAPFloat_ARITHTOAPFloat_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
1313
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
1414
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
15+
#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
1516
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
1617
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
1718
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,19 @@ def ArithToLLVMConversionPass : Pass<"convert-arith-to-llvm"> {
186186
];
187187
}
188188

189+
//===----------------------------------------------------------------------===//
190+
// ArithToAPFloat
191+
//===----------------------------------------------------------------------===//
192+
193+
def ArithToAPFloatConversionPass : Pass<"convert-arith-to-apfloat"> {
194+
let summary = "Convert Arith dialect ops on FP8 types to APFloat lib calls";
195+
let description = [{
196+
This pass converts supported Arith ops which manipulate FP8 typed values to APFloat lib calls.
197+
}];
198+
let dependentDialects = ["func::FuncDialect"];
199+
let options = [];
200+
}
201+
189202
//===----------------------------------------------------------------------===//
190203
// ArithToSPIRV
191204
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Func/Utils/Utils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,14 @@ mlir::FailureOr<std::pair<mlir::func::FuncOp, mlir::func::CallOp>>
6060
deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
6161
mlir::ModuleOp moduleOp);
6262

63+
/// Create a FuncOp with signature `resultTypes`(`paramTypes`)` and name `name`.
64+
/// Return a failure if the FuncOp found has unexpected signature.
65+
FailureOr<FuncOp>
66+
lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
67+
ArrayRef<Type> paramTypes = {},
68+
ArrayRef<Type> resultTypes = {}, bool setPrivate = false,
69+
SymbolTableCollection *symbolTables = nullptr);
70+
6371
} // namespace func
6472
} // namespace mlir
6573

mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -56,23 +56,6 @@ FailureOr<LLVM::LLVMFuncOp>
5656
lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
5757
SymbolTableCollection *symbolTables = nullptr);
5858

59-
#define APFLOAT_BIN_OPS(X) \
60-
X(add) \
61-
X(subtract) \
62-
X(multiply) \
63-
X(divide) \
64-
X(remainder) \
65-
X(mod)
66-
67-
#define LOOKUP_OR_CREATE_APFLOAT_FN_DECL(OP) \
68-
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateApFloat##OP##Fn( \
69-
OpBuilder &b, Operation *moduleOp, \
70-
SymbolTableCollection *symbolTables = nullptr);
71-
72-
APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DECL)
73-
74-
#undef LOOKUP_OR_CREATE_APFLOAT_FN_DECL
75-
7659
/// Declares a function to print a C-string.
7760
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
7861
/// have the signature void(char const*). The default function is `printString`.
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
//===- ArithToAPFloat.cpp - Arithmetic to APFloat impl conversion ---------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
10+
11+
#include "mlir/Dialect/Arith/IR/Arith.h"
12+
#include "mlir/Dialect/Arith/Transforms/Passes.h"
13+
#include "mlir/Dialect/Func/IR/FuncOps.h"
14+
#include "mlir/Dialect/Func/Utils/Utils.h"
15+
#include "mlir/IR/Verifier.h"
16+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17+
18+
namespace mlir {
19+
#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
20+
#include "mlir/Conversion/Passes.h.inc"
21+
} // namespace mlir
22+
23+
using namespace mlir;
24+
using namespace mlir::func;
25+
26+
#define APFLOAT_BIN_OPS(X) \
27+
X(add) \
28+
X(subtract) \
29+
X(multiply) \
30+
X(divide) \
31+
X(remainder) \
32+
X(mod)
33+
34+
#define APFLOAT_EXTERN_K(OP) kApFloat_##OP
35+
36+
#define APFLOAT_EXTERN_NAME(OP) \
37+
static constexpr llvm::StringRef APFLOAT_EXTERN_K(OP) = "_mlir_" \
38+
"apfloat_" #OP;
39+
40+
namespace mlir::func {
41+
#define LOOKUP_OR_CREATE_APFLOAT_FN_DECL(OP) \
42+
FailureOr<FuncOp> lookupOrCreateApFloat##OP##Fn( \
43+
OpBuilder &b, Operation *moduleOp, \
44+
SymbolTableCollection *symbolTables = nullptr);
45+
46+
APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DECL)
47+
48+
#undef LOOKUP_OR_CREATE_APFLOAT_FN_DECL
49+
50+
APFLOAT_BIN_OPS(APFLOAT_EXTERN_NAME)
51+
52+
#define LOOKUP_OR_CREATE_APFLOAT_FN_DEFN(OP) \
53+
FailureOr<FuncOp> lookupOrCreateApFloat##OP##Fn( \
54+
OpBuilder &b, Operation *moduleOp, \
55+
SymbolTableCollection *symbolTables) { \
56+
return lookupOrCreateFn(b, moduleOp, APFLOAT_EXTERN_K(OP), \
57+
{IntegerType::get(moduleOp->getContext(), 32), \
58+
IntegerType::get(moduleOp->getContext(), 64), \
59+
IntegerType::get(moduleOp->getContext(), 64)}, \
60+
{IntegerType::get(moduleOp->getContext(), 64)}, \
61+
/*setPrivate*/ true, symbolTables); \
62+
}
63+
64+
APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DEFN)
65+
#undef LOOKUP_OR_CREATE_APFLOAT_FN_DEFN
66+
} // namespace mlir::func
67+
68+
struct FancyAddFLowering : OpRewritePattern<arith::AddFOp> {
69+
using OpRewritePattern::OpRewritePattern;
70+
71+
LogicalResult matchAndRewrite(arith::AddFOp op,
72+
PatternRewriter &rewriter) const override {
73+
// Get APFloat adder function from runtime library.
74+
auto parent = op->getParentOfType<ModuleOp>();
75+
if (!parent)
76+
return failure();
77+
if (!llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
78+
Float8E5M2FNUZType, Float8E4M3FNUZType,
79+
Float8E4M3B11FNUZType, Float8E3M4Type, Float4E2M1FNType,
80+
Float6E2M3FNType, Float6E3M2FNType, Float8E8M0FNUType>(
81+
op.getType()))
82+
return failure();
83+
FailureOr<Operation *> adder = lookupOrCreateApFloataddFn(rewriter, parent);
84+
85+
// Cast operands to 64-bit integers.
86+
Location loc = op.getLoc();
87+
auto floatTy = cast<FloatType>(op.getType());
88+
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
89+
auto int64Type = rewriter.getI64Type();
90+
Value lhsBits = arith::ExtUIOp::create(
91+
rewriter, loc, int64Type,
92+
arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
93+
Value rhsBits = arith::ExtUIOp::create(
94+
rewriter, loc, int64Type,
95+
arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
96+
97+
// Call software implementation of floating point addition.
98+
int32_t sem =
99+
llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
100+
Value semValue = arith::ConstantOp::create(
101+
rewriter, loc, rewriter.getI32Type(),
102+
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
103+
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
104+
auto resultOp =
105+
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
106+
SymbolRefAttr::get(*adder), params);
107+
108+
// Truncate result to the original width.
109+
Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
110+
resultOp->getResult(0));
111+
rewriter.replaceAllUsesWith(
112+
op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
113+
return success();
114+
}
115+
};
116+
117+
void arith::populateArithToAPFloatConversionPatterns(
118+
RewritePatternSet &patterns) {
119+
patterns.add<FancyAddFLowering>(patterns.getContext());
120+
}
121+
122+
namespace {
123+
struct ArithToAPFloatConversionPass final
124+
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
125+
using impl::ArithToAPFloatConversionPassBase<
126+
ArithToAPFloatConversionPass>::ArithToAPFloatConversionPassBase;
127+
128+
void runOnOperation() override {
129+
Operation *op = getOperation();
130+
RewritePatternSet patterns(op->getContext());
131+
arith::populateArithToAPFloatConversionPatterns(patterns);
132+
if (failed(applyPatternsGreedily(op, std::move(patterns))))
133+
return signalPassFailure();
134+
}
135+
};
136+
} // namespace
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
add_mlir_conversion_library(MLIRArithToAPFloat
2+
ArithToAPFloat.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIRArithDialect
15+
MLIRArithTransforms
16+
MLIRFuncDialect
17+
)

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -573,53 +573,6 @@ void mlir::arith::registerConvertArithToLLVMInterface(
573573
});
574574
}
575575

576-
struct FancyAddFLowering : public ConvertOpToLLVMPattern<arith::AddFOp> {
577-
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
578-
579-
LogicalResult
580-
matchAndRewrite(arith::AddFOp op, OpAdaptor adaptor,
581-
ConversionPatternRewriter &rewriter) const override {
582-
// Get APFloat adder function from runtime library.
583-
auto parent = op->getParentOfType<ModuleOp>();
584-
if (!parent)
585-
return failure();
586-
if (!llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
587-
Float8E5M2FNUZType, Float8E4M3FNUZType,
588-
Float8E4M3B11FNUZType, Float8E3M4Type, Float4E2M1FNType,
589-
Float6E2M3FNType, Float6E3M2FNType, Float8E8M0FNUType>(
590-
op.getType()))
591-
return failure();
592-
auto floatTy = cast<FloatType>(op.getType());
593-
FailureOr<Operation *> adder =
594-
LLVM::lookupOrCreateApFloatAddFFn(rewriter, parent);
595-
596-
// Cast operands to 64-bit integers.
597-
Location loc = op.getLoc();
598-
Value lhsBits = LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(),
599-
adaptor.getLhs());
600-
Value rhsBits = LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(),
601-
adaptor.getRhs());
602-
603-
// Call software implementation of floating point addition.
604-
int32_t sem =
605-
llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
606-
Value semValue = LLVM::ConstantOp::create(
607-
rewriter, loc, rewriter.getI32Type(),
608-
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
609-
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
610-
auto resultOp =
611-
LLVM::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
612-
SymbolRefAttr::get(*adder), params);
613-
614-
// Truncate result to the original width.
615-
Value truncatedBits = LLVM::TruncOp::create(
616-
rewriter, loc, rewriter.getIntegerType(floatTy.getWidth()),
617-
resultOp->getResult(0));
618-
rewriter.replaceOp(op, truncatedBits);
619-
return success();
620-
}
621-
};
622-
623576
//===----------------------------------------------------------------------===//
624577
// Pattern Population
625578
//===----------------------------------------------------------------------===//
@@ -635,7 +588,6 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
635588
// clang-format off
636589
patterns.add<
637590
AddFOpLowering,
638-
FancyAddFLowering,
639591
AddIOpLowering,
640592
AndIOpLowering,
641593
AddUIExtendedOpLowering,

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
22
add_subdirectory(AMDGPUToROCDL)
33
add_subdirectory(ArithCommon)
44
add_subdirectory(ArithToAMDGPU)
5+
add_subdirectory(ArithToAPFloat)
56
add_subdirectory(ArithToArmSME)
67
add_subdirectory(ArithToEmitC)
78
add_subdirectory(ArithToLLVM)

mlir/lib/Dialect/Func/Utils/Utils.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,45 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
254254

255255
return std::make_pair(*newFuncOpOrFailure, newCallOp);
256256
}
257+
258+
FailureOr<func::FuncOp>
259+
func::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
260+
ArrayRef<Type> paramTypes, ArrayRef<Type> resultTypes,
261+
bool setPrivate, SymbolTableCollection *symbolTables) {
262+
assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
263+
"expected SymbolTable operation");
264+
265+
FuncOp func;
266+
if (symbolTables) {
267+
func = symbolTables->lookupSymbolIn<FuncOp>(
268+
moduleOp, StringAttr::get(moduleOp->getContext(), name));
269+
} else {
270+
func = llvm::dyn_cast_or_null<FuncOp>(
271+
SymbolTable::lookupSymbolIn(moduleOp, name));
272+
}
273+
274+
FunctionType funcT =
275+
FunctionType::get(b.getContext(), paramTypes, resultTypes);
276+
// Assert the signature of the found function is same as expected
277+
if (func) {
278+
if (funcT != func.getFunctionType()) {
279+
func.emitError("redefinition of function '")
280+
<< name << "' of different type " << funcT << " is prohibited";
281+
return failure();
282+
}
283+
return func;
284+
}
285+
286+
OpBuilder::InsertionGuard g(b);
287+
assert(!moduleOp->getRegion(0).empty() && "expected non-empty region");
288+
b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
289+
FuncOp funcOp = FuncOp::create(b, moduleOp->getLoc(), name, funcT);
290+
if (setPrivate)
291+
funcOp.setPrivate();
292+
if (symbolTables) {
293+
SymbolTable &symbolTable = symbolTables->getSymbolTable(moduleOp);
294+
symbolTable.insert(funcOp, moduleOp->getRegion(0).front().begin());
295+
}
296+
297+
return funcOp;
298+
}

0 commit comments

Comments
 (0)