Skip to content

Commit 7899470

Browse files
[mlir][arith] Add support for extf, truncf to ArithToAPFloat (#169275)
Add support for `arith.extf` and `arith.truncf`. No support for custom rounding modes yet.
1 parent 2f8e712 commit 7899470

File tree

4 files changed

+130
-24
lines changed

4 files changed

+130
-24
lines changed

mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp

Lines changed: 81 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -41,24 +41,15 @@ static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
4141
}
4242

4343
/// Helper function to look up or create the symbol for a runtime library
44-
/// function for a binary arithmetic operation.
45-
///
46-
/// Parameter 1: APFloat semantics
47-
/// Parameter 2: Left-hand side operand
48-
/// Parameter 3: Right-hand side operand
49-
///
50-
/// This function will return a failure if the function is found but has an
51-
/// unexpected signature.
52-
///
44+
/// function with the given parameter types. Always returns an int64_t.
5345
static FailureOr<FuncOp>
54-
lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
55-
SymbolTableCollection *symbolTables = nullptr) {
56-
auto i32Type = IntegerType::get(symTable->getContext(), 32);
46+
lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
47+
StringRef name, TypeRange paramTypes,
48+
SymbolTableCollection *symbolTables = nullptr) {
5749
auto i64Type = IntegerType::get(symTable->getContext(), 64);
5850

5951
std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
60-
FunctionType funcT =
61-
FunctionType::get(b.getContext(), {i32Type, i64Type, i64Type}, {i64Type});
52+
auto funcT = FunctionType::get(b.getContext(), paramTypes, {i64Type});
6253
FailureOr<FuncOp> func =
6354
lookupFnDecl(symTable, funcName, funcT, symbolTables);
6455
// Failed due to type mismatch.
@@ -72,6 +63,31 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
7263
/*setPrivate=*/true, symbolTables);
7364
}
7465

66+
/// Helper function to look up or create the symbol for a runtime library
67+
/// function for a binary arithmetic operation.
68+
///
69+
/// Parameter 1: APFloat semantics
70+
/// Parameter 2: Left-hand side operand
71+
/// Parameter 3: Right-hand side operand
72+
///
73+
/// This function will return a failure if the function is found but has an
74+
/// unexpected signature.
75+
///
76+
static FailureOr<FuncOp>
77+
lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
78+
SymbolTableCollection *symbolTables = nullptr) {
79+
auto i32Type = IntegerType::get(symTable->getContext(), 32);
80+
auto i64Type = IntegerType::get(symTable->getContext(), 64);
81+
return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type},
82+
symbolTables);
83+
}
84+
85+
static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
86+
int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
87+
return arith::ConstantOp::create(b, loc, b.getI32Type(),
88+
b.getIntegerAttr(b.getI32Type(), sem));
89+
}
90+
7591
/// Rewrite a binary arithmetic operation to an APFloat function call.
7692
template <typename OpTy>
7793
struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
@@ -104,11 +120,7 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
104120
arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
105121

106122
// Call APFloat function.
107-
int32_t sem =
108-
llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
109-
Value semValue = arith::ConstantOp::create(
110-
rewriter, loc, rewriter.getI32Type(),
111-
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
123+
Value semValue = getSemanticsValue(rewriter, loc, floatTy);
112124
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
113125
auto resultOp =
114126
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
@@ -126,6 +138,53 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
126138
const char *APFloatName;
127139
};
128140

141+
template <typename OpTy>
142+
struct FpToFpConversion final : OpRewritePattern<OpTy> {
143+
FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
144+
PatternBenefit benefit = 1)
145+
: OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {}
146+
147+
LogicalResult matchAndRewrite(OpTy op,
148+
PatternRewriter &rewriter) const override {
149+
// Get APFloat function from runtime library.
150+
auto i32Type = IntegerType::get(symTable->getContext(), 32);
151+
auto i64Type = IntegerType::get(symTable->getContext(), 64);
152+
FailureOr<FuncOp> fn = lookupOrCreateApFloatFn(
153+
rewriter, symTable, "convert", {i32Type, i32Type, i64Type});
154+
if (failed(fn))
155+
return fn;
156+
157+
rewriter.setInsertionPoint(op);
158+
// Cast operands to 64-bit integers.
159+
Location loc = op.getLoc();
160+
auto inFloatTy = cast<FloatType>(op.getOperand().getType());
161+
auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
162+
auto int64Type = rewriter.getI64Type();
163+
Value operandBits = arith::ExtUIOp::create(
164+
rewriter, loc, int64Type,
165+
arith::BitcastOp::create(rewriter, loc, inIntWType, op.getOperand()));
166+
167+
// Call APFloat function.
168+
Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
169+
auto outFloatTy = cast<FloatType>(op.getType());
170+
Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
171+
std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
172+
auto resultOp =
173+
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
174+
SymbolRefAttr::get(*fn), params);
175+
176+
// Truncate result to the original width.
177+
auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
178+
Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntWType,
179+
resultOp->getResult(0));
180+
rewriter.replaceOp(
181+
op, arith::BitcastOp::create(rewriter, loc, outFloatTy, truncatedBits));
182+
return success();
183+
}
184+
185+
SymbolOpInterface symTable;
186+
};
187+
129188
namespace {
130189
struct ArithToAPFloatConversionPass final
131190
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -147,6 +206,9 @@ void ArithToAPFloatConversionPass::runOnOperation() {
147206
context, "divide", getOperation());
148207
patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>(
149208
context, "remainder", getOperation());
209+
patterns
210+
.add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>>(
211+
context, getOperation());
150212
LogicalResult result = success();
151213
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
152214
if (diag.getSeverity() == DiagnosticSeverity::Error) {

mlir/lib/ExecutionEngine/APFloatWrappers.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151

5252
/// Binary operations with rounding mode.
5353
#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE) \
54-
MLIR_APFLOAT_WRAPPERS_EXPORT int64_t _mlir_apfloat_##OP( \
54+
MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_##OP( \
5555
int32_t semantics, uint64_t a, uint64_t b) { \
5656
const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \
5757
static_cast<llvm::APFloatBase::Semantics>(semantics)); \
@@ -86,4 +86,19 @@ MLIR_APFLOAT_WRAPPERS_EXPORT void printApFloat(int32_t semantics, uint64_t a) {
8686
double d = x.convertToDouble();
8787
fprintf(stdout, "%lg", d);
8888
}
89+
90+
MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t
91+
_mlir_apfloat_convert(int32_t inSemantics, int32_t outSemantics, uint64_t a) {
92+
const llvm::fltSemantics &inSem = llvm::APFloatBase::EnumToSemantics(
93+
static_cast<llvm::APFloatBase::Semantics>(inSemantics));
94+
const llvm::fltSemantics &outSem = llvm::APFloatBase::EnumToSemantics(
95+
static_cast<llvm::APFloatBase::Semantics>(outSemantics));
96+
unsigned bitWidthIn = llvm::APFloatBase::semanticsSizeInBits(inSem);
97+
llvm::APFloat val(inSem, llvm::APInt(bitWidthIn, a));
98+
// TODO: Custom rounding modes are not supported yet.
99+
bool losesInfo;
100+
val.convert(outSem, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
101+
llvm::APInt result = val.bitcastToAPInt();
102+
return result.getZExtValue();
103+
}
89104
}

mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,25 @@ func.func @remf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
126126
%0 = arith.remf %arg0, %arg1 : f4E2M1FN
127127
return
128128
}
129+
130+
// -----
131+
132+
// CHECK: func.func private @_mlir_apfloat_convert(i32, i32, i64) -> i64
133+
// CHECK: %[[sem_in:.*]] = arith.constant 18 : i32
134+
// CHECK: %[[sem_out:.*]] = arith.constant 2 : i32
135+
// CHECK: call @_mlir_apfloat_convert(%[[sem_in]], %[[sem_out]], %{{.*}}) : (i32, i32, i64) -> i64
136+
func.func @extf(%arg0: f4E2M1FN) {
137+
%0 = arith.extf %arg0 : f4E2M1FN to f32
138+
return
139+
}
140+
141+
// -----
142+
143+
// CHECK: func.func private @_mlir_apfloat_convert(i32, i32, i64) -> i64
144+
// CHECK: %[[sem_in:.*]] = arith.constant 1 : i32
145+
// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32
146+
// CHECK: call @_mlir_apfloat_convert(%[[sem_in]], %[[sem_out]], %{{.*}}) : (i32, i32, i64) -> i64
147+
func.func @truncf(%arg0: bf16) {
148+
%0 = arith.truncf %arg0 : bf16 to f4E2M1FN
149+
return
150+
}

mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,21 @@ func.func @entry() {
2727
%a1 = arith.constant 1.4 : f8E4M3FN
2828
%a2 = arith.constant 1.4 : f32
2929
%b1, %b2 = func.call @foo() : () -> (f8E4M3FN, f32)
30-
%c1 = arith.addf %a1, %b1 : f8E4M3FN // not supported by LLVM
31-
%c2 = arith.addf %a2, %b2 : f32 // supported by LLVM
3230

33-
// CHECK: 3.5
31+
// CHECK: 2.2
32+
vector.print %b2 : f32
33+
34+
// CHECK-NEXT: 3.5
35+
%c1 = arith.addf %a1, %b1 : f8E4M3FN // not supported by LLVM
3436
vector.print %c1 : f8E4M3FN
3537

36-
// CHECK: 3.6
38+
// CHECK-NEXT: 3.6
39+
%c2 = arith.addf %a2, %b2 : f32 // supported by LLVM
3740
vector.print %c2 : f32
3841

42+
// CHECK-NEXT: 2.25
43+
%cvt = arith.truncf %b2 : f32 to f8E4M3FN
44+
vector.print %cvt : f8E4M3FN
45+
3946
return
4047
}

0 commit comments

Comments
 (0)