Skip to content

Commit 61821f2

Browse files
matthias-springerkcloudy0717
authored andcommitted
[mlir][arith] Add support for negf to ArithToAPFloat (llvm#169759)
Add support for `arith.negf`.
1 parent 96de28f commit 61821f2

File tree

4 files changed

+68
-2
lines changed

4 files changed

+68
-2
lines changed

mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,49 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
449449
SymbolOpInterface symTable;
450450
};
451451

452+
struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
453+
NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
454+
PatternBenefit benefit = 1)
455+
: OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable) {}
456+
457+
LogicalResult matchAndRewrite(arith::NegFOp op,
458+
PatternRewriter &rewriter) const override {
459+
// Get APFloat function from runtime library.
460+
auto i32Type = IntegerType::get(symTable->getContext(), 32);
461+
auto i64Type = IntegerType::get(symTable->getContext(), 64);
462+
FailureOr<FuncOp> fn =
463+
lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type});
464+
if (failed(fn))
465+
return fn;
466+
467+
// Cast operand to 64-bit integer.
468+
rewriter.setInsertionPoint(op);
469+
Location loc = op.getLoc();
470+
auto floatTy = cast<FloatType>(op.getOperand().getType());
471+
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
472+
Value operandBits = arith::ExtUIOp::create(
473+
rewriter, loc, i64Type, arith::BitcastOp::create(rewriter, loc, intWType, op.getOperand()));
474+
475+
// Call APFloat function.
476+
Value semValue = getSemanticsValue(rewriter, loc, floatTy);
477+
SmallVector<Value> params = {semValue, operandBits};
478+
Value negatedBits =
479+
func::CallOp::create(rewriter, loc, TypeRange(i64Type),
480+
SymbolRefAttr::get(*fn), params)
481+
->getResult(0);
482+
483+
// Truncate result to the original width.
484+
Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
485+
negatedBits);
486+
Value result =
487+
arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits);
488+
rewriter.replaceOp(op, result);
489+
return success();
490+
}
491+
492+
SymbolOpInterface symTable;
493+
};
494+
452495
namespace {
453496
struct ArithToAPFloatConversionPass final
454497
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -471,7 +514,8 @@ void ArithToAPFloatConversionPass::runOnOperation() {
471514
patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>(
472515
context, "remainder", getOperation());
473516
patterns
474-
.add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>>(
517+
.add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>,
518+
CmpFOpToAPFloatConversion, NegFOpToAPFloatConversion>(
475519
context, getOperation());
476520
patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(),
477521
/*isUnsigned=*/false);
@@ -481,7 +525,6 @@ void ArithToAPFloatConversionPass::runOnOperation() {
481525
/*isUnsigned=*/false);
482526
patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
483527
/*isUnsigned=*/true);
484-
patterns.add<CmpFOpToAPFloatConversion>(context, getOperation());
485528
LogicalResult result = success();
486529
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
487530
if (diag.getSeverity() == DiagnosticSeverity::Error) {

mlir/lib/ExecutionEngine/APFloatWrappers.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,13 @@ MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics,
142142
llvm::APFloat y(sem, llvm::APInt(bitWidth, b));
143143
return static_cast<int8_t>(x.compare(y));
144144
}
145+
146+
MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_neg(int32_t semantics, uint64_t a) {
147+
const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
148+
static_cast<llvm::APFloatBase::Semantics>(semantics));
149+
unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
150+
llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
151+
x.changeSign();
152+
return x.bitcastToAPInt().getZExtValue();
153+
}
145154
}

mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,13 @@ func.func @cmpf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
213213
%0 = arith.cmpf "ult", %arg0, %arg1 : f4E2M1FN
214214
return
215215
}
216+
217+
// -----
218+
219+
// CHECK: func.func private @_mlir_apfloat_neg(i32, i64) -> i64
220+
// CHECK: %[[sem:.*]] = arith.constant 2 : i32
221+
// CHECK: %[[res:.*]] = call @_mlir_apfloat_neg(%[[sem]], %{{.*}}) : (i32, i64) -> i64
222+
func.func @negf(%arg0: f32) {
223+
%0 = arith.negf %arg0 : f32
224+
return
225+
}

mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ func.func @entry() {
4343
%cvt = arith.truncf %b2 : f32 to f8E4M3FN
4444
vector.print %cvt : f8E4M3FN
4545

46+
// CHECK-NEXT: -2.25
47+
%negated = arith.negf %cvt : f8E4M3FN
48+
vector.print %negated : f8E4M3FN
49+
4650
// CHECK-NEXT: 1
4751
%cmp1 = arith.cmpf "olt", %cvt, %c1 : f8E4M3FN
4852
vector.print %cmp1 : i1

0 commit comments

Comments
 (0)