Skip to content

Commit b57a01a

Browse files
[SYCL] Fix apparently incorrect bfloat16 conversions (#20243)
These conversions looked wrong, but I don't have any tests for them. We can either merge this as-is or create an issue for the original author(s) to follow up. It seems that the `bfloat16` APIs don't need any changes, since C++23's `std::bfloat16_t` behaves in a similar way (although it is a fundamental type rather than a class type).
1 parent bd1e50d commit b57a01a

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

sycl/include/sycl/ext/oneapi/bfloat16.hpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -407,25 +407,25 @@ inline bfloat16 getBFloat16FromDoubleWithRTE(const double &d) {
407407
// handling +/-infinity and NAN for double input
408408
if (fp64_exp == 0x7FF) {
409409
if (!fp64_mant)
410-
return bf16_sign ? 0xFF80 : 0x7F80;
410+
return bit_cast<bfloat16, uint16_t>(bf16_sign ? 0xFF80 : 0x7F80);
411411

412412
// returns a quiet NaN
413-
return 0x7FC0;
413+
return bit_cast<bfloat16, uint16_t>(0x7FC0);
414414
}
415415

416416
// Subnormal double precision is converted to 0
417417
if (fp64_exp == 0)
418-
return bf16_sign ? 0x8000 : 0x0;
418+
return bit_cast<bfloat16, uint16_t>(bf16_sign ? 0x8000 : 0x0);
419419

420420
fp64_exp -= 1023;
421421

422422
// handling overflow, convert to +/-infinity
423423
if (static_cast<int16_t>(fp64_exp) > 127)
424-
return bf16_sign ? 0xFF80 : 0x7F80;
424+
return bit_cast<bfloat16, uint16_t>(bf16_sign ? 0xFF80 : 0x7F80);
425425

426426
// handling underflow
427427
if (static_cast<int16_t>(fp64_exp) < -133)
428-
return bf16_sign ? 0x8000 : 0x0;
428+
return bit_cast<bfloat16, uint16_t>(bf16_sign ? 0x8000 : 0x0);
429429

430430
// -133 <= fp64_exp <= 127, 1.significand * 2^fp64_exp
431431
// For these numbers, they are NOT subnormal double-precision numbers but
@@ -444,7 +444,8 @@ inline bfloat16 getBFloat16FromDoubleWithRTE(const double &d) {
444444
bf16_mant = 0;
445445
fp64_exp = 1;
446446
}
447-
return (bf16_sign << 15) | (fp64_exp << 7) | bf16_mant;
447+
return bit_cast<bfloat16, uint16_t>((bf16_sign << 15) | (fp64_exp << 7) |
448+
bf16_mant);
448449
}
449450

450451
// For normal value, discard 45 bits from mantissa
@@ -462,7 +463,8 @@ inline bfloat16 getBFloat16FromDoubleWithRTE(const double &d) {
462463
}
463464
fp64_exp += 127;
464465

465-
return (bf16_sign << 15) | (fp64_exp << 7) | bf16_mant;
466+
return bit_cast<bfloat16, uint16_t>((bf16_sign << 15) | (fp64_exp << 7) |
467+
bf16_mant);
466468
}
467469

468470
// Function to get the most significant bit position of a number.

0 commit comments

Comments
 (0)