@@ -41,24 +41,15 @@ static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
4141}
4242
4343// / Helper function to look up or create the symbol for a runtime library
44- // / function for a binary arithmetic operation.
45- // /
46- // / Parameter 1: APFloat semantics
47- // / Parameter 2: Left-hand side operand
48- // / Parameter 3: Right-hand side operand
49- // /
50- // / This function will return a failure if the function is found but has an
51- // / unexpected signature.
52- // /
44+ // / function with the given parameter types. Always returns an int64_t.
5345static FailureOr<FuncOp>
54- lookupOrCreateBinaryFn (OpBuilder &b, SymbolOpInterface symTable, StringRef name ,
55- SymbolTableCollection *symbolTables = nullptr ) {
56- auto i32Type = IntegerType::get (symTable-> getContext (), 32 );
46+ lookupOrCreateApFloatFn (OpBuilder &b, SymbolOpInterface symTable,
47+ StringRef name, TypeRange paramTypes,
48+ SymbolTableCollection *symbolTables = nullptr ) {
5749 auto i64Type = IntegerType::get (symTable->getContext (), 64 );
5850
5951 std::string funcName = (llvm::Twine (" _mlir_apfloat_" ) + name).str ();
60- FunctionType funcT =
61- FunctionType::get (b.getContext (), {i32Type, i64Type, i64Type}, {i64Type});
52+ auto funcT = FunctionType::get (b.getContext (), paramTypes, {i64Type});
6253 FailureOr<FuncOp> func =
6354 lookupFnDecl (symTable, funcName, funcT, symbolTables);
6455 // Failed due to type mismatch.
@@ -72,6 +63,31 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
7263 /* setPrivate=*/ true , symbolTables);
7364}
7465
66+ // / Helper function to look up or create the symbol for a runtime library
67+ // / function for a binary arithmetic operation.
68+ // /
69+ // / Parameter 1: APFloat semantics
70+ // / Parameter 2: Left-hand side operand
71+ // / Parameter 3: Right-hand side operand
72+ // /
73+ // / This function will return a failure if the function is found but has an
74+ // / unexpected signature.
75+ // /
76+ static FailureOr<FuncOp>
77+ lookupOrCreateBinaryFn (OpBuilder &b, SymbolOpInterface symTable, StringRef name,
78+ SymbolTableCollection *symbolTables = nullptr ) {
79+ auto i32Type = IntegerType::get (symTable->getContext (), 32 );
80+ auto i64Type = IntegerType::get (symTable->getContext (), 64 );
81+ return lookupOrCreateApFloatFn (b, symTable, name, {i32Type, i64Type, i64Type},
82+ symbolTables);
83+ }
84+
85+ static Value getSemanticsValue (OpBuilder &b, Location loc, FloatType floatTy) {
86+ int32_t sem = llvm::APFloatBase::SemanticsToEnum (floatTy.getFloatSemantics ());
87+ return arith::ConstantOp::create (b, loc, b.getI32Type (),
88+ b.getIntegerAttr (b.getI32Type (), sem));
89+ }
90+
7591// / Rewrite a binary arithmetic operation to an APFloat function call.
7692template <typename OpTy>
7793struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
@@ -104,11 +120,7 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
104120 arith::BitcastOp::create (rewriter, loc, intWType, op.getRhs ()));
105121
106122 // Call APFloat function.
107- int32_t sem =
108- llvm::APFloatBase::SemanticsToEnum (floatTy.getFloatSemantics ());
109- Value semValue = arith::ConstantOp::create (
110- rewriter, loc, rewriter.getI32Type (),
111- rewriter.getIntegerAttr (rewriter.getI32Type (), sem));
123+ Value semValue = getSemanticsValue (rewriter, loc, floatTy);
112124 SmallVector<Value> params = {semValue, lhsBits, rhsBits};
113125 auto resultOp =
114126 func::CallOp::create (rewriter, loc, TypeRange (rewriter.getI64Type ()),
@@ -126,6 +138,53 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
126138 const char *APFloatName;
127139};
128140
141+ template <typename OpTy>
142+ struct FpToFpConversion final : OpRewritePattern<OpTy> {
143+ FpToFpConversion (MLIRContext *context, SymbolOpInterface symTable,
144+ PatternBenefit benefit = 1 )
145+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {}
146+
147+ LogicalResult matchAndRewrite (OpTy op,
148+ PatternRewriter &rewriter) const override {
149+ // Get APFloat function from runtime library.
150+ auto i32Type = IntegerType::get (symTable->getContext (), 32 );
151+ auto i64Type = IntegerType::get (symTable->getContext (), 64 );
152+ FailureOr<FuncOp> fn = lookupOrCreateApFloatFn (
153+ rewriter, symTable, " convert" , {i32Type, i32Type, i64Type});
154+ if (failed (fn))
155+ return fn;
156+
157+ rewriter.setInsertionPoint (op);
158+ // Cast operands to 64-bit integers.
159+ Location loc = op.getLoc ();
160+ auto inFloatTy = cast<FloatType>(op.getOperand ().getType ());
161+ auto inIntWType = rewriter.getIntegerType (inFloatTy.getWidth ());
162+ auto int64Type = rewriter.getI64Type ();
163+ Value operandBits = arith::ExtUIOp::create (
164+ rewriter, loc, int64Type,
165+ arith::BitcastOp::create (rewriter, loc, inIntWType, op.getOperand ()));
166+
167+ // Call APFloat function.
168+ Value inSemValue = getSemanticsValue (rewriter, loc, inFloatTy);
169+ auto outFloatTy = cast<FloatType>(op.getType ());
170+ Value outSemValue = getSemanticsValue (rewriter, loc, outFloatTy);
171+ std::array<Value, 3 > params = {inSemValue, outSemValue, operandBits};
172+ auto resultOp =
173+ func::CallOp::create (rewriter, loc, TypeRange (rewriter.getI64Type ()),
174+ SymbolRefAttr::get (*fn), params);
175+
176+ // Truncate result to the original width.
177+ auto outIntWType = rewriter.getIntegerType (outFloatTy.getWidth ());
178+ Value truncatedBits = arith::TruncIOp::create (rewriter, loc, outIntWType,
179+ resultOp->getResult (0 ));
180+ rewriter.replaceOp (
181+ op, arith::BitcastOp::create (rewriter, loc, outFloatTy, truncatedBits));
182+ return success ();
183+ }
184+
185+ SymbolOpInterface symTable;
186+ };
187+
129188namespace {
130189struct ArithToAPFloatConversionPass final
131190 : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -147,6 +206,9 @@ void ArithToAPFloatConversionPass::runOnOperation() {
147206 context, " divide" , getOperation ());
148207 patterns.add <BinaryArithOpToAPFloatConversion<arith::RemFOp>>(
149208 context, " remainder" , getOperation ());
209+ patterns
210+ .add <FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>>(
211+ context, getOperation ());
150212 LogicalResult result = success ();
151213 ScopedDiagnosticHandler scopedHandler (context, [&result](Diagnostic &diag) {
152214 if (diag.getSeverity () == DiagnosticSeverity::Error) {
0 commit comments