Skip to content

Commit 4ce6f24

Browse files
committed
add arith-to-apfloat
1 parent 3680b11 commit 4ce6f24

File tree

9 files changed

+85
-89
lines changed

9 files changed

+85
-89
lines changed

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
1313
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
1414
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
15+
#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
1516
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
1617
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
1718
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,19 @@ def ArithToLLVMConversionPass : Pass<"convert-arith-to-llvm"> {
186186
];
187187
}
188188

189+
//===----------------------------------------------------------------------===//
190+
// ArithToAPFloat
191+
//===----------------------------------------------------------------------===//
192+
193+
def ArithToAPFloatConversionPass : Pass<"convert-arith-to-apfloat"> {
194+
let summary = "Convert Arith dialect ops on FP8 types to APFloat lib calls";
195+
let description = [{
196+
This pass converts supported Arith ops which manipulate FP8 typed values to APFloat lib calls.
197+
}];
198+
let dependentDialects = ["func::FuncDialect"];
199+
let options = [];
200+
}
201+
189202
//===----------------------------------------------------------------------===//
190203
// ArithToSPIRV
191204
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Func/Utils/Utils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,14 @@ mlir::FailureOr<std::pair<mlir::func::FuncOp, mlir::func::CallOp>>
6060
deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
6161
mlir::ModuleOp moduleOp);
6262

63+
/// Create a FuncOp with signature `resultTypes`(`paramTypes`)` and name `name`.
64+
/// Return a failure if the FuncOp found has unexpected signature.
65+
FailureOr<FuncOp>
66+
lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
67+
ArrayRef<Type> paramTypes = {},
68+
ArrayRef<Type> resultTypes = {}, bool setPrivate = false,
69+
SymbolTableCollection *symbolTables = nullptr);
70+
6371
} // namespace func
6472
} // namespace mlir
6573

mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -56,23 +56,6 @@ FailureOr<LLVM::LLVMFuncOp>
5656
lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
5757
SymbolTableCollection *symbolTables = nullptr);
5858

59-
#define APFLOAT_BIN_OPS(X) \
60-
X(add) \
61-
X(subtract) \
62-
X(multiply) \
63-
X(divide) \
64-
X(remainder) \
65-
X(mod)
66-
67-
#define LOOKUP_OR_CREATE_APFLOAT_FN_DECL(OP) \
68-
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateApFloat##OP##Fn( \
69-
OpBuilder &b, Operation *moduleOp, \
70-
SymbolTableCollection *symbolTables = nullptr);
71-
72-
APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DECL)
73-
74-
#undef LOOKUP_OR_CREATE_APFLOAT_FN_DECL
75-
7659
/// Declares a function to print a C-string.
7760
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
7861
/// have the signature void(char const*). The default function is `printString`.

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -573,53 +573,6 @@ void mlir::arith::registerConvertArithToLLVMInterface(
573573
});
574574
}
575575

576-
struct FancyAddFLowering : public ConvertOpToLLVMPattern<arith::AddFOp> {
577-
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
578-
579-
LogicalResult
580-
matchAndRewrite(arith::AddFOp op, OpAdaptor adaptor,
581-
ConversionPatternRewriter &rewriter) const override {
582-
// Get APFloat adder function from runtime library.
583-
auto parent = op->getParentOfType<ModuleOp>();
584-
if (!parent)
585-
return failure();
586-
if (!llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
587-
Float8E5M2FNUZType, Float8E4M3FNUZType,
588-
Float8E4M3B11FNUZType, Float8E3M4Type, Float4E2M1FNType,
589-
Float6E2M3FNType, Float6E3M2FNType, Float8E8M0FNUType>(
590-
op.getType()))
591-
return failure();
592-
auto floatTy = cast<FloatType>(op.getType());
593-
FailureOr<Operation *> adder =
594-
LLVM::lookupOrCreateApFloatAddFFn(rewriter, parent);
595-
596-
// Cast operands to 64-bit integers.
597-
Location loc = op.getLoc();
598-
Value lhsBits = LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(),
599-
adaptor.getLhs());
600-
Value rhsBits = LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(),
601-
adaptor.getRhs());
602-
603-
// Call software implementation of floating point addition.
604-
int32_t sem =
605-
llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
606-
Value semValue = LLVM::ConstantOp::create(
607-
rewriter, loc, rewriter.getI32Type(),
608-
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
609-
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
610-
auto resultOp =
611-
LLVM::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
612-
SymbolRefAttr::get(*adder), params);
613-
614-
// Truncate result to the original width.
615-
Value truncatedBits = LLVM::TruncOp::create(
616-
rewriter, loc, rewriter.getIntegerType(floatTy.getWidth()),
617-
resultOp->getResult(0));
618-
rewriter.replaceOp(op, truncatedBits);
619-
return success();
620-
}
621-
};
622-
623576
//===----------------------------------------------------------------------===//
624577
// Pattern Population
625578
//===----------------------------------------------------------------------===//
@@ -635,7 +588,6 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
635588
// clang-format off
636589
patterns.add<
637590
AddFOpLowering,
638-
FancyAddFLowering,
639591
AddIOpLowering,
640592
AndIOpLowering,
641593
AddUIExtendedOpLowering,

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
22
add_subdirectory(AMDGPUToROCDL)
33
add_subdirectory(ArithCommon)
44
add_subdirectory(ArithToAMDGPU)
5+
add_subdirectory(ArithToAPFloat)
56
add_subdirectory(ArithToArmSME)
67
add_subdirectory(ArithToEmitC)
78
add_subdirectory(ArithToLLVM)

mlir/lib/Dialect/Func/Utils/Utils.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,45 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
254254

255255
return std::make_pair(*newFuncOpOrFailure, newCallOp);
256256
}
257+
258+
FailureOr<func::FuncOp>
259+
func::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
260+
ArrayRef<Type> paramTypes, ArrayRef<Type> resultTypes,
261+
bool setPrivate, SymbolTableCollection *symbolTables) {
262+
assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
263+
"expected SymbolTable operation");
264+
265+
FuncOp func;
266+
if (symbolTables) {
267+
func = symbolTables->lookupSymbolIn<FuncOp>(
268+
moduleOp, StringAttr::get(moduleOp->getContext(), name));
269+
} else {
270+
func = llvm::dyn_cast_or_null<FuncOp>(
271+
SymbolTable::lookupSymbolIn(moduleOp, name));
272+
}
273+
274+
FunctionType funcT =
275+
FunctionType::get(b.getContext(), paramTypes, resultTypes);
276+
// Assert the signature of the found function is same as expected
277+
if (func) {
278+
if (funcT != func.getFunctionType()) {
279+
func.emitError("redefinition of function '")
280+
<< name << "' of different type " << funcT << " is prohibited";
281+
return failure();
282+
}
283+
return func;
284+
}
285+
286+
OpBuilder::InsertionGuard g(b);
287+
assert(!moduleOp->getRegion(0).empty() && "expected non-empty region");
288+
b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
289+
FuncOp funcOp = FuncOp::create(b, moduleOp->getLoc(), name, funcT);
290+
if (setPrivate)
291+
funcOp.setPrivate();
292+
if (symbolTables) {
293+
SymbolTable &symbolTable = symbolTables->getSymbolTable(moduleOp);
294+
symbolTable.insert(funcOp, moduleOp->getRegion(0).front().begin());
295+
}
296+
297+
return funcOp;
298+
}

mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,6 @@ static constexpr llvm::StringRef kPrintBF16 = "printBF16";
3131
static constexpr llvm::StringRef kPrintF32 = "printF32";
3232
static constexpr llvm::StringRef kPrintF64 = "printF64";
3333
static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
34-
35-
#define APFLOAT_EXTERN_K(OP) kApFloat_##OP
36-
37-
#define APFLOAT_EXTERN_NAME(OP) \
38-
static constexpr llvm::StringRef APFLOAT_EXTERN_K(OP) = "APFloat_" #OP;
39-
40-
APFLOAT_BIN_OPS(APFLOAT_EXTERN_NAME)
41-
4234
static constexpr llvm::StringRef kPrintString = "printString";
4335
static constexpr llvm::StringRef kPrintOpen = "printOpen";
4436
static constexpr llvm::StringRef kPrintClose = "printClose";
@@ -179,21 +171,6 @@ mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
179171
LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
180172
}
181173

182-
#define LOOKUP_OR_CREATE_APFLOAT_FN_DEFN(OP) \
183-
FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateApFloat##OP##Fn( \
184-
OpBuilder &b, Operation *moduleOp, \
185-
SymbolTableCollection *symbolTables) { \
186-
return lookupOrCreateReservedFn( \
187-
b, moduleOp, APFLOAT_EXTERN_K(OP), \
188-
{IntegerType::get(moduleOp->getContext(), 32), \
189-
IntegerType::get(moduleOp->getContext(), 64), \
190-
IntegerType::get(moduleOp->getContext(), 64)}, \
191-
IntegerType::get(moduleOp->getContext(), 64), symbolTables); \
192-
}
193-
194-
APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DEFN)
195-
#undef LOOKUP_OR_CREATE_APFLOAT_FN_DEFN
196-
197174
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
198175
return LLVM::LLVMPointerType::get(context);
199176
}

mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// Check that the ceildivsi lowering is correct.
22
// We do not check any poison or UB values, as it is not possible to catch them.
33

4-
// RUN: mlir-opt %s --convert-to-llvm
4+
// RUN: mlir-opt %s --convert-arith-to-apfloat
55

66
// Put rhs into separate function so that it won't be constant-folded.
77
func.func @foo() -> f4E2M1FN {
@@ -17,3 +17,22 @@ func.func @entry() {
1717
return
1818
}
1919

20+
// CHECK-LABEL: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64
21+
22+
// CHECK-LABEL: func.func @foo() -> f4E2M1FN {
23+
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 4.000000e+00 : f4E2M1FN
24+
// CHECK: return %[[CONSTANT_0]] : f4E2M1FN
25+
// CHECK: }
26+
27+
// CHECK-LABEL: func.func @entry() {
28+
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 18 : i32
29+
// CHECK: %[[CONSTANT_1:.*]] = arith.constant 6 : i64
30+
// CHECK: %[[VAL_0:.*]] = call @foo() : () -> f4E2M1FN
31+
// CHECK: %[[BITCAST_0:.*]] = arith.bitcast %[[VAL_0]] : f4E2M1FN to i4
32+
// CHECK: %[[EXTUI_0:.*]] = arith.extui %[[BITCAST_0]] : i4 to i64
33+
// CHECK: %[[VAL_1:.*]] = call @_mlir_apfloat_add(%[[CONSTANT_0]], %[[EXTUI_0]], %[[CONSTANT_1]]) : (i32, i64, i64) -> i64
34+
// CHECK: %[[TRUNCI_0:.*]] = arith.trunci %[[VAL_1]] : i64 to i4
35+
// CHECK: %[[BITCAST_1:.*]] = arith.bitcast %[[TRUNCI_0]] : i4 to f4E2M1FN
36+
// CHECK: vector.print %[[BITCAST_1]] : f4E2M1FN
37+
// CHECK: return
38+
// CHECK: }

0 commit comments

Comments
 (0)