2121#include " mlir/IR/DialectImplementation.h"
2222#include " mlir/IR/DialectResourceBlobManager.h"
2323#include " mlir/IR/IntegerSet.h"
24+ #include " llvm/ADT/APFloat.h"
2425#include " llvm/ADT/StringExtras.h"
2526#include " llvm/Support/Endian.h"
27+ #include < cmath>
2628#include < optional>
2729
2830using namespace mlir ;
@@ -121,14 +123,16 @@ Attribute Parser::parseAttribute(Type type) {
121123
122124 // Parse floating point and integer attributes.
123125 case Token::floatliteral:
126+ case Token::kw_inf:
127+ case Token::kw_nan:
124128 return parseFloatAttr (type, /* isNegative=*/ false );
125129 case Token::integer:
126130 return parseDecOrHexAttr (type, /* isNegative=*/ false );
127131 case Token::minus: {
128132 consumeToken (Token::minus);
129133 if (getToken ().is (Token::integer))
130134 return parseDecOrHexAttr (type, /* isNegative=*/ true );
131- if (getToken ().is (Token::floatliteral))
135+ if (getToken ().isAny (Token::floatliteral, Token::kw_inf, Token::kw_nan ))
132136 return parseFloatAttr (type, /* isNegative=*/ true );
133137
134138 return (emitWrongTokenError (
@@ -342,21 +346,25 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
342346
343347// / Parse a float attribute.
344348Attribute Parser::parseFloatAttr (Type type, bool isNegative) {
345- auto val = getToken ().getFloatingPointValue ();
346- if (!val)
347- return (emitError (" floating point value too large for attribute" ), nullptr );
348- consumeToken (Token::floatliteral);
349+ const Token tok = getToken ();
350+ consumeToken ();
349351 if (!type) {
350352 // Default to F64 when no type is specified.
351353 if (!consumeIf (Token::colon))
352354 type = builder.getF64Type ();
353355 else if (!(type = parseType ()))
354356 return nullptr ;
355357 }
356- if (!isa<FloatType>(type))
357- return (emitError (" floating point value not valid for specified type" ),
358+ auto floatType = dyn_cast<FloatType>(type);
359+ if (!floatType)
360+ return (emitError (tok.getLoc (),
361+ " floating point value not valid for specified type" ),
358362 nullptr );
359- return FloatAttr::get (type, isNegative ? -*val : *val);
363+ std::optional<APFloat> apResult;
364+ if (failed (parseFloatFromLiteral (apResult, tok, isNegative,
365+ floatType.getFloatSemantics ())))
366+ return Attribute ();
367+ return FloatAttr::get (floatType, *apResult);
360368}
361369
362370// / Construct an APint from a parsed value, a known attribute type and
@@ -622,7 +630,7 @@ TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
622630 }
623631
624632 // Check to see if floating point values were parsed.
625- if (token.is (Token::floatliteral)) {
633+ if (token.isAny (Token::floatliteral, Token::kw_inf, Token::kw_nan )) {
626634 return p.emitError (tokenLoc)
627635 << " expected integer elements, but parsed floating-point" ;
628636 }
@@ -729,6 +737,8 @@ ParseResult TensorLiteralParser::parseElement() {
729737 // Parse a boolean element.
730738 case Token::kw_true:
731739 case Token::kw_false:
740+ case Token::kw_inf:
741+ case Token::kw_nan:
732742 case Token::floatliteral:
733743 case Token::integer:
734744 storage.emplace_back (/* isNegative=*/ false , p.getToken ());
@@ -738,7 +748,8 @@ ParseResult TensorLiteralParser::parseElement() {
738748 // Parse a signed integer or a negative floating-point element.
739749 case Token::minus:
740750 p.consumeToken (Token::minus);
741- if (!p.getToken ().isAny (Token::floatliteral, Token::integer))
751+ if (!p.getToken ().isAny (Token::floatliteral, Token::kw_inf, Token::kw_nan,
752+ Token::integer))
742753 return p.emitError (" expected integer or floating point literal" );
743754 storage.emplace_back (/* isNegative=*/ true , p.getToken ());
744755 p.consumeToken ();
0 commit comments