@@ -99,6 +99,14 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
9999 return builder.getF32FloatAttr (dstVal.convertToFloat ());
100100}
101101
102+ // Get IntegerAttr from FloatAttr.
103+ IntegerAttr getIntegerAttrFromFloatAttr (FloatAttr floatAttr, Type dstType,
104+ ConversionPatternRewriter &rewriter) {
105+ APFloat floatVal = floatAttr.getValue ();
106+ APInt intVal = floatVal.bitcastToAPInt ();
107+ return rewriter.getIntegerAttr (dstType, intVal);
108+ }
109+
102110// / Returns true if the given `type` is a boolean scalar or vector type.
103111static bool isBoolScalarOrVector (Type type) {
104112 assert (type && " Not a valid type" );
@@ -296,8 +304,16 @@ struct ConstantCompositeOpPattern final
296304 SmallVector<Attribute, 8 > elements;
297305 if (isa<FloatType>(srcElemType)) {
298306 for (FloatAttr srcAttr : dstElementsAttr.getValues <FloatAttr>()) {
299- FloatAttr dstAttr =
300- convertFloatAttr (srcAttr, cast<FloatType>(dstElemType), rewriter);
307+ Attribute dstAttr = nullptr ;
308+ // Handle 8-bit float conversion to 8-bit integer.
309+ if (srcElemType.getIntOrFloatBitWidth () == 8 &&
310+ isa<IntegerType>(dstElemType)) {
311+ dstAttr =
312+ getIntegerAttrFromFloatAttr (srcAttr, dstElemType, rewriter);
313+ } else {
314+ dstAttr = convertFloatAttr (srcAttr, cast<FloatType>(dstElemType),
315+ rewriter);
316+ }
301317 if (!dstAttr)
302318 return failure ();
303319 elements.push_back (dstAttr);
@@ -361,11 +377,17 @@ struct ConstantScalarOpPattern final
361377 // Floating-point types.
362378 if (isa<FloatType>(srcType)) {
363379 auto srcAttr = cast<FloatAttr>(cstAttr);
364- auto dstAttr = srcAttr;
380+ Attribute dstAttr = srcAttr;
365381
366382 // Floating-point types not supported in the target environment are all
367383 // converted to float type.
368- if (srcType != dstType) {
384+ if (srcType.getIntOrFloatBitWidth () == 8 && isa<IntegerType>(dstType) &&
385+ dstType.getIntOrFloatBitWidth () == 8 ) {
386+ // If the source is an 8-bit float, convert it to a 8-bit integer.
387+ dstAttr = getIntegerAttrFromFloatAttr (srcAttr, dstType, rewriter);
388+ if (!dstAttr)
389+ return failure ();
390+ } else if (srcType != dstType) {
369391 dstAttr = convertFloatAttr (srcAttr, cast<FloatType>(dstType), rewriter);
370392 if (!dstAttr)
371393 return failure ();
0 commit comments