@@ -449,6 +449,49 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
449449 SymbolOpInterface symTable;
450450};
451451
452+ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
453+ NegFOpToAPFloatConversion (MLIRContext *context, SymbolOpInterface symTable,
454+ PatternBenefit benefit = 1 )
455+ : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable) {}
456+
457+ LogicalResult matchAndRewrite (arith::NegFOp op,
458+ PatternRewriter &rewriter) const override {
459+ // Get APFloat function from runtime library.
460+ auto i32Type = IntegerType::get (symTable->getContext (), 32 );
461+ auto i64Type = IntegerType::get (symTable->getContext (), 64 );
462+ FailureOr<FuncOp> fn =
463+ lookupOrCreateApFloatFn (rewriter, symTable, " neg" , {i32Type, i64Type});
464+ if (failed (fn))
465+ return fn;
466+
467+ // Cast operand to 64-bit integer.
468+ rewriter.setInsertionPoint (op);
469+ Location loc = op.getLoc ();
470+ auto floatTy = cast<FloatType>(op.getOperand ().getType ());
471+ auto intWType = rewriter.getIntegerType (floatTy.getWidth ());
472+ Value operandBits = arith::ExtUIOp::create (
473+ rewriter, loc, i64Type, arith::BitcastOp::create (rewriter, loc, intWType, op.getOperand ()));
474+
475+ // Call APFloat function.
476+ Value semValue = getSemanticsValue (rewriter, loc, floatTy);
477+ SmallVector<Value> params = {semValue, operandBits};
478+ Value negatedBits =
479+ func::CallOp::create (rewriter, loc, TypeRange (i64Type),
480+ SymbolRefAttr::get (*fn), params)
481+ ->getResult (0 );
482+
483+ // Truncate result to the original width.
484+ Value truncatedBits = arith::TruncIOp::create (rewriter, loc, intWType,
485+ negatedBits);
486+ Value result =
487+ arith::BitcastOp::create (rewriter, loc, floatTy, truncatedBits);
488+ rewriter.replaceOp (op, result);
489+ return success ();
490+ }
491+
492+ SymbolOpInterface symTable;
493+ };
494+
452495namespace {
453496struct ArithToAPFloatConversionPass final
454497 : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -471,7 +514,8 @@ void ArithToAPFloatConversionPass::runOnOperation() {
471514 patterns.add <BinaryArithOpToAPFloatConversion<arith::RemFOp>>(
472515 context, " remainder" , getOperation ());
473516 patterns
474- .add <FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>>(
517+ .add <FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>,
518+ CmpFOpToAPFloatConversion, NegFOpToAPFloatConversion>(
475519 context, getOperation ());
476520 patterns.add <FpToIntConversion<arith::FPToSIOp>>(context, getOperation (),
477521 /* isUnsigned=*/ false );
@@ -481,7 +525,6 @@ void ArithToAPFloatConversionPass::runOnOperation() {
481525 /* isUnsigned=*/ false );
482526 patterns.add <IntToFpConversion<arith::UIToFPOp>>(context, getOperation (),
483527 /* isUnsigned=*/ true );
484- patterns.add <CmpFOpToAPFloatConversion>(context, getOperation ());
485528 LogicalResult result = success ();
486529 ScopedDiagnosticHandler scopedHandler (context, [&result](Diagnostic &diag) {
487530 if (diag.getSeverity () == DiagnosticSeverity::Error) {
0 commit comments