Skip to content

Commit 6ec6867

Browse files
[mlir][arith] Add support for sitofp, uitofp to ArithToAPFloat (#169284)
Add support for `arith.sitofp` and `arith.uitofp`.
1 parent 8217c64 commit 6ec6867

File tree

4 files changed

+120
-0
lines changed

4 files changed

+120
-0
lines changed

mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,73 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
241241
bool isUnsigned;
242242
};
243243

244+
template <typename OpTy>
245+
struct IntToFpConversion final : OpRewritePattern<OpTy> {
246+
IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
247+
bool isUnsigned, PatternBenefit benefit = 1)
248+
: OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
249+
isUnsigned(isUnsigned) {}
250+
251+
LogicalResult matchAndRewrite(OpTy op,
252+
PatternRewriter &rewriter) const override {
253+
Location loc = op.getLoc();
254+
if (op.getIn().getType().getIntOrFloatBitWidth() > 64) {
255+
return rewriter.notifyMatchFailure(
256+
loc, "integer bitwidth > 64 is not supported");
257+
}
258+
259+
// Get APFloat function from runtime library.
260+
auto i1Type = IntegerType::get(symTable->getContext(), 1);
261+
auto i32Type = IntegerType::get(symTable->getContext(), 32);
262+
auto i64Type = IntegerType::get(symTable->getContext(), 64);
263+
FailureOr<FuncOp> fn =
264+
lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
265+
{i32Type, i32Type, i1Type, i64Type});
266+
if (failed(fn))
267+
return fn;
268+
269+
rewriter.setInsertionPoint(op);
270+
// Cast operands to 64-bit integers.
271+
auto inIntTy = cast<IntegerType>(op.getOperand().getType());
272+
Value operandBits = op.getOperand();
273+
if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
274+
if (isUnsigned) {
275+
operandBits =
276+
arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
277+
} else {
278+
operandBits =
279+
arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits);
280+
}
281+
}
282+
283+
// Call APFloat function.
284+
auto outFloatTy = cast<FloatType>(op.getType());
285+
Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
286+
Value inWidthValue = arith::ConstantOp::create(
287+
rewriter, loc, i32Type,
288+
rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
289+
Value isUnsignedValue = arith::ConstantOp::create(
290+
rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned));
291+
SmallVector<Value> params = {outSemValue, inWidthValue, isUnsignedValue,
292+
operandBits};
293+
auto resultOp =
294+
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
295+
SymbolRefAttr::get(*fn), params);
296+
297+
// Truncate result to the original width.
298+
auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
299+
Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntWType,
300+
resultOp->getResult(0));
301+
Value result =
302+
arith::BitcastOp::create(rewriter, loc, outFloatTy, truncatedBits);
303+
rewriter.replaceOp(op, result);
304+
return success();
305+
}
306+
307+
SymbolOpInterface symTable;
308+
bool isUnsigned;
309+
};
310+
244311
namespace {
245312
struct ArithToAPFloatConversionPass final
246313
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -269,6 +336,10 @@ void ArithToAPFloatConversionPass::runOnOperation() {
269336
/*isUnsigned=*/false);
270337
patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
271338
/*isUnsigned=*/true);
339+
patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(),
340+
/*isUnsigned=*/false);
341+
patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
342+
/*isUnsigned=*/true);
272343
LogicalResult result = success();
273344
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
274345
if (diag.getSeverity() == DiagnosticSeverity::Error) {

mlir/lib/ExecutionEngine/APFloatWrappers.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,4 +119,16 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_to_int(
119119
// result to the desired result width.
120120
return result.getZExtValue();
121121
}
122+
123+
MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_from_int(
124+
int32_t semantics, int32_t inputWidth, bool isUnsigned, uint64_t a) {
125+
llvm::APInt val(inputWidth, a, /*isSigned=*/!isUnsigned);
126+
const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
127+
static_cast<llvm::APFloatBase::Semantics>(semantics));
128+
llvm::APFloat result(sem);
129+
// TODO: Custom rounding modes are not supported yet.
130+
result.convertFromAPInt(val, /*IsSigned=*/!isUnsigned,
131+
llvm::RoundingMode::NearestTiesToEven);
132+
return result.bitcastToAPInt().getZExtValue();
133+
}
122134
}

mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,27 @@ func.func @fptoui(%arg0: f16) {
174174
%0 = arith.fptoui %arg0 : f16 to i4
175175
return
176176
}
177+
178+
// -----
179+
180+
// CHECK: func.func private @_mlir_apfloat_convert_from_int(i32, i32, i1, i64) -> i64
181+
// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32
182+
// CHECK: %[[in_width:.*]] = arith.constant 32 : i32
183+
// CHECK: %[[is_unsigned:.*]] = arith.constant false
184+
// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_from_int(%[[sem_out]], %[[in_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64
185+
func.func @sitofp(%arg0: i32) {
186+
%0 = arith.sitofp %arg0 : i32 to f4E2M1FN
187+
return
188+
}
189+
190+
// -----
191+
192+
// CHECK: func.func private @_mlir_apfloat_convert_from_int(i32, i32, i1, i64) -> i64
193+
// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32
194+
// CHECK: %[[in_width:.*]] = arith.constant 32 : i32
195+
// CHECK: %[[is_unsigned:.*]] = arith.constant true
196+
// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_from_int(%[[sem_out]], %[[in_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64
197+
func.func @uitofp(%arg0: i32) {
198+
%0 = arith.uitofp %arg0 : i32 to f4E2M1FN
199+
return
200+
}

mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,5 +53,18 @@ func.func @entry() {
5353
%cvt_int_unsigned = arith.fptoui %cvt : f8E4M3FN to i2
5454
vector.print %cvt_int_unsigned : i2
5555

56+
// CHECK-NEXT: -6
57+
// Bit pattern: 1...11110111, interpreted as signed: -9
58+
// Closest f4E2M1FN value: -6.0
59+
%c9 = arith.constant -9 : i16
60+
%cvt_from_signed_int = arith.sitofp %c9 : i16 to f4E2M1FN
61+
vector.print %cvt_from_signed_int : f4E2M1FN
62+
63+
// CHECK-NEXT: 6
64+
// Bit pattern: 1...11110111, interpreted as unsigned: 65527
65+
// Closest f4E2M1FN value: 6.0
66+
%cvt_from_unsigned_int = arith.uitofp %c9 : i16 to f4E2M1FN
67+
vector.print %cvt_from_unsigned_int : f4E2M1FN
68+
5669
return
5770
}

0 commit comments

Comments
 (0)