77// ===----------------------------------------------------------------------===//
88
99#include " mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
10+ #include " Utils.h"
1011
1112#include " mlir/Dialect/Arith/IR/Arith.h"
1213#include " mlir/Dialect/Arith/Transforms/Passes.h"
@@ -25,47 +26,6 @@ namespace mlir {
2526using namespace mlir ;
2627using namespace mlir ::func;
2728
28- static FuncOp createFnDecl (OpBuilder &b, SymbolOpInterface symTable,
29- StringRef name, FunctionType funcT, bool setPrivate,
30- SymbolTableCollection *symbolTables = nullptr ) {
31- OpBuilder::InsertionGuard g (b);
32- assert (!symTable->getRegion (0 ).empty () && " expected non-empty region" );
33- b.setInsertionPointToStart (&symTable->getRegion (0 ).front ());
34- FuncOp funcOp = FuncOp::create (b, symTable->getLoc (), name, funcT);
35- if (setPrivate)
36- funcOp.setPrivate ();
37- if (symbolTables) {
38- SymbolTable &symbolTable = symbolTables->getSymbolTable (symTable);
39- symbolTable.insert (funcOp, symTable->getRegion (0 ).front ().begin ());
40- }
41- return funcOp;
42- }
43-
44- // / Helper function to look up or create the symbol for a runtime library
45- // / function with the given parameter types. Returns an int64_t, unless a
46- // / different result type is specified.
47- static FailureOr<FuncOp>
48- lookupOrCreateApFloatFn (OpBuilder &b, SymbolOpInterface symTable,
49- StringRef name, TypeRange paramTypes,
50- SymbolTableCollection *symbolTables = nullptr ,
51- Type resultType = {}) {
52- if (!resultType)
53- resultType = IntegerType::get (symTable->getContext (), 64 );
54- std::string funcName = (llvm::Twine (" _mlir_apfloat_" ) + name).str ();
55- auto funcT = FunctionType::get (b.getContext (), paramTypes, {resultType});
56- FailureOr<FuncOp> func =
57- lookupFnDecl (symTable, funcName, funcT, symbolTables);
58- // Failed due to type mismatch.
59- if (failed (func))
60- return func;
61- // Successfully matched existing decl.
62- if (*func)
63- return *func;
64-
65- return createFnDecl (b, symTable, funcName, funcT,
66- /* setPrivate=*/ true , symbolTables);
67- }
68-
6929// / Helper function to look up or create the symbol for a runtime library
7030// / function for a binary arithmetic operation.
7131// /
@@ -81,14 +41,9 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
8141 SymbolTableCollection *symbolTables = nullptr ) {
8242 auto i32Type = IntegerType::get (symTable->getContext (), 32 );
8343 auto i64Type = IntegerType::get (symTable->getContext (), 64 );
84- return lookupOrCreateApFloatFn (b, symTable, name, {i32Type, i64Type, i64Type},
85- symbolTables);
86- }
87-
88- static Value getSemanticsValue (OpBuilder &b, Location loc, FloatType floatTy) {
89- int32_t sem = llvm::APFloatBase::SemanticsToEnum (floatTy.getFloatSemantics ());
90- return arith::ConstantOp::create (b, loc, b.getI32Type (),
91- b.getIntegerAttr (b.getI32Type (), sem));
44+ std::string funcName = (llvm::Twine (" _mlir_apfloat_" ) + name).str ();
45+ return lookupOrCreateFnDecl (b, symTable, funcName, {i32Type, i64Type, i64Type},
46+ symbolTables);
9247}
9348
9449// / Given two operands of vector type and vector result type (with the same
@@ -197,7 +152,7 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
197152 arith::BitcastOp::create (rewriter, loc, intWType, rhs));
198153
199154 // Call APFloat function.
200- Value semValue = getSemanticsValue (rewriter, loc, floatTy);
155+ Value semValue = getAPFloatSemanticsValue (rewriter, loc, floatTy);
201156 SmallVector<Value> params = {semValue, lhsBits, rhsBits};
202157 auto resultOp = func::CallOp::create (rewriter, loc,
203158 TypeRange (rewriter.getI64Type ()),
@@ -231,8 +186,9 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
231186 // Get APFloat function from runtime library.
232187 auto i32Type = IntegerType::get (symTable->getContext (), 32 );
233188 auto i64Type = IntegerType::get (symTable->getContext (), 64 );
234- FailureOr<FuncOp> fn = lookupOrCreateApFloatFn (
235- rewriter, symTable, " convert" , {i32Type, i32Type, i64Type});
189+ FailureOr<FuncOp> fn =
190+ lookupOrCreateFnDecl (rewriter, symTable, " _mlir_apfloat_convert" ,
191+ {i32Type, i32Type, i64Type});
236192 if (failed (fn))
237193 return fn;
238194
@@ -250,9 +206,10 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> {
250206 arith::BitcastOp::create (rewriter, loc, inIntWType, operand1));
251207
252208 // Call APFloat function.
253- Value inSemValue = getSemanticsValue (rewriter, loc, inFloatTy);
209+ Value inSemValue = getAPFloatSemanticsValue (rewriter, loc, inFloatTy);
254210 auto outFloatTy = cast<FloatType>(resultType);
255- Value outSemValue = getSemanticsValue (rewriter, loc, outFloatTy);
211+ Value outSemValue =
212+ getAPFloatSemanticsValue (rewriter, loc, outFloatTy);
256213 std::array<Value, 3 > params = {inSemValue, outSemValue, operandBits};
257214 auto resultOp = func::CallOp::create (rewriter, loc,
258215 TypeRange (rewriter.getI64Type ()),
@@ -289,8 +246,8 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
289246 auto i32Type = IntegerType::get (symTable->getContext (), 32 );
290247 auto i64Type = IntegerType::get (symTable->getContext (), 64 );
291248 FailureOr<FuncOp> fn =
292- lookupOrCreateApFloatFn (rewriter, symTable, " convert_to_int " ,
293- {i32Type, i32Type, i1Type, i64Type});
249+ lookupOrCreateFnDecl (rewriter, symTable, " _mlir_apfloat_convert_to_int " ,
250+ {i32Type, i32Type, i1Type, i64Type});
294251 if (failed (fn))
295252 return fn;
296253
@@ -308,7 +265,7 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
308265 arith::BitcastOp::create (rewriter, loc, inIntWType, operand1));
309266
310267 // Call APFloat function.
311- Value inSemValue = getSemanticsValue (rewriter, loc, inFloatTy);
268+ Value inSemValue = getAPFloatSemanticsValue (rewriter, loc, inFloatTy);
312269 auto outIntTy = cast<IntegerType>(resultType);
313270 Value outWidthValue = arith::ConstantOp::create (
314271 rewriter, loc, i32Type,
@@ -351,8 +308,8 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
351308 auto i32Type = IntegerType::get (symTable->getContext (), 32 );
352309 auto i64Type = IntegerType::get (symTable->getContext (), 64 );
353310 FailureOr<FuncOp> fn =
354- lookupOrCreateApFloatFn (rewriter, symTable, " convert_from_int " ,
355- {i32Type, i32Type, i1Type, i64Type});
311+ lookupOrCreateFnDecl (rewriter, symTable, " _mlir_apfloat_convert_from_int " ,
312+ {i32Type, i32Type, i1Type, i64Type});
356313 if (failed (fn))
357314 return fn;
358315
@@ -377,7 +334,8 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
377334
378335 // Call APFloat function.
379336 auto outFloatTy = cast<FloatType>(resultType);
380- Value outSemValue = getSemanticsValue (rewriter, loc, outFloatTy);
337+ Value outSemValue =
338+ getAPFloatSemanticsValue (rewriter, loc, outFloatTy);
381339 Value inWidthValue = arith::ConstantOp::create (
382340 rewriter, loc, i32Type,
383341 rewriter.getIntegerAttr (i32Type, inIntTy.getWidth ()));
@@ -421,8 +379,8 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
421379 auto i32Type = IntegerType::get (symTable->getContext (), 32 );
422380 auto i64Type = IntegerType::get (symTable->getContext (), 64 );
423381 FailureOr<FuncOp> fn =
424- lookupOrCreateApFloatFn (rewriter, symTable, " compare " ,
425- {i32Type, i64Type, i64Type}, nullptr , i8Type);
382+ lookupOrCreateFnDecl (rewriter, symTable, " _mlir_apfloat_compare " ,
383+ {i32Type, i64Type, i64Type}, nullptr , i8Type);
426384 if (failed (fn))
427385 return fn;
428386
@@ -443,7 +401,7 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
443401 arith::BitcastOp::create (rewriter, loc, intWType, rhs));
444402
445403 // Call APFloat function.
446- Value semValue = getSemanticsValue (rewriter, loc, floatTy);
404+ Value semValue = getAPFloatSemanticsValue (rewriter, loc, floatTy);
447405 SmallVector<Value> params = {semValue, lhsBits, rhsBits};
448406 Value comparisonResult =
449407 func::CallOp::create (rewriter, loc, TypeRange (i8Type),
@@ -569,8 +527,8 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
569527 // Get APFloat function from runtime library.
570528 auto i32Type = IntegerType::get (symTable->getContext (), 32 );
571529 auto i64Type = IntegerType::get (symTable->getContext (), 64 );
572- FailureOr<FuncOp> fn =
573- lookupOrCreateApFloatFn ( rewriter, symTable, " neg " , {i32Type, i64Type});
530+ FailureOr<FuncOp> fn = lookupOrCreateFnDecl (
531+ rewriter, symTable, " _mlir_apfloat_neg " , {i32Type, i64Type});
574532 if (failed (fn))
575533 return fn;
576534
@@ -588,7 +546,7 @@ struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
588546 arith::BitcastOp::create (rewriter, loc, intWType, operand1));
589547
590548 // Call APFloat function.
591- Value semValue = getSemanticsValue (rewriter, loc, floatTy);
549+ Value semValue = getAPFloatSemanticsValue (rewriter, loc, floatTy);
592550 SmallVector<Value> params = {semValue, operandBits};
593551 Value negatedBits =
594552 func::CallOp::create (rewriter, loc, TypeRange (i64Type),
0 commit comments