@@ -159,9 +159,8 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
159159 Location loc = op.getLoc ();
160160 auto inFloatTy = cast<FloatType>(op.getOperand ().getType ());
161161 auto inIntWType = rewriter.getIntegerType (inFloatTy.getWidth ());
162- auto int64Type = rewriter.getI64Type ();
163162 Value operandBits = arith::ExtUIOp::create (
164- rewriter, loc, int64Type ,
163+ rewriter, loc, i64Type ,
165164 arith::BitcastOp::create (rewriter, loc, inIntWType, op.getOperand ()));
166165
167166 // Call APFloat function.
@@ -185,6 +184,63 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
185184 SymbolOpInterface symTable;
186185};
187186
187+ template <typename OpTy>
188+ struct FpToIntConversion final : OpRewritePattern<OpTy> {
189+ FpToIntConversion (MLIRContext *context, SymbolOpInterface symTable,
190+ bool isUnsigned, PatternBenefit benefit = 1 )
191+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
192+ isUnsigned (isUnsigned) {}
193+
194+ LogicalResult matchAndRewrite (OpTy op,
195+ PatternRewriter &rewriter) const override {
196+ if (op.getType ().getIntOrFloatBitWidth () > 64 )
197+ return rewriter.notifyMatchFailure (
198+ op, " result type > 64 bits is not supported" );
199+
200+ // Get APFloat function from runtime library.
201+ auto i1Type = IntegerType::get (symTable->getContext (), 1 );
202+ auto i32Type = IntegerType::get (symTable->getContext (), 32 );
203+ auto i64Type = IntegerType::get (symTable->getContext (), 64 );
204+ FailureOr<FuncOp> fn =
205+ lookupOrCreateApFloatFn (rewriter, symTable, " convert_to_int" ,
206+ {i32Type, i32Type, i1Type, i64Type});
207+ if (failed (fn))
208+ return fn;
209+
210+ rewriter.setInsertionPoint (op);
211+ // Cast operands to 64-bit integers.
212+ Location loc = op.getLoc ();
213+ auto inFloatTy = cast<FloatType>(op.getOperand ().getType ());
214+ auto inIntWType = rewriter.getIntegerType (inFloatTy.getWidth ());
215+ Value operandBits = arith::ExtUIOp::create (
216+ rewriter, loc, i64Type,
217+ arith::BitcastOp::create (rewriter, loc, inIntWType, op.getOperand ()));
218+
219+ // Call APFloat function.
220+ Value inSemValue = getSemanticsValue (rewriter, loc, inFloatTy);
221+ auto outIntTy = cast<IntegerType>(op.getType ());
222+ Value outWidthValue = arith::ConstantOp::create (
223+ rewriter, loc, i32Type,
224+ rewriter.getIntegerAttr (i32Type, outIntTy.getWidth ()));
225+ Value isUnsignedValue = arith::ConstantOp::create (
226+ rewriter, loc, i1Type, rewriter.getIntegerAttr (i1Type, isUnsigned));
227+ SmallVector<Value> params = {inSemValue, outWidthValue, isUnsignedValue,
228+ operandBits};
229+ auto resultOp =
230+ func::CallOp::create (rewriter, loc, TypeRange (rewriter.getI64Type ()),
231+ SymbolRefAttr::get (*fn), params);
232+
233+ // Truncate result to the original width.
234+ Value truncatedBits = arith::TruncIOp::create (rewriter, loc, outIntTy,
235+ resultOp->getResult (0 ));
236+ rewriter.replaceOp (op, truncatedBits);
237+ return success ();
238+ }
239+
240+ SymbolOpInterface symTable;
241+ bool isUnsigned;
242+ };
243+
188244namespace {
189245struct ArithToAPFloatConversionPass final
190246 : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -209,6 +265,10 @@ void ArithToAPFloatConversionPass::runOnOperation() {
209265 patterns
210266 .add <FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>>(
211267 context, getOperation ());
268+ patterns.add <FpToIntConversion<arith::FPToSIOp>>(context, getOperation (),
269+ /* isUnsigned=*/ false );
270+ patterns.add <FpToIntConversion<arith::FPToUIOp>>(context, getOperation (),
271+ /* isUnsigned=*/ true );
212272 LogicalResult result = success ();
213273 ScopedDiagnosticHandler scopedHandler (context, [&result](Diagnostic &diag) {
214274 if (diag.getSeverity () == DiagnosticSeverity::Error) {
0 commit comments