Skip to content

Commit b93bc43

Browse files
authored
[SYCL] Fix corner case in imf function 'fdiv_rn' (#20427)
Triton XPU backend developer reports imf function fdiv_rn returns incorrect result for corner case fdiv_rn(1.0f, 0.0f). The expected should be inf instead of nan. This PR fixes this issue. Signed-off-by: jinge90 <[email protected]>
1 parent 30e1d2f commit b93bc43

File tree

5 files changed

+81
-24
lines changed

5 files changed

+81
-24
lines changed

libdevice/imf_rounding_op.hpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -670,12 +670,24 @@ template <typename Ty> Ty __fp_div(Ty x, Ty y, int rd) {
670670
const UTy sig_off_mask = (one_bits << (sizeof(UTy) * 8 - 1)) - 1;
671671

672672
if (((x_exp == __iml_fp_config<Ty>::exp_mask) && (x_fra != 0x0)) ||
673-
((y_exp == __iml_fp_config<Ty>::exp_mask) && (y_fra != 0x0)) ||
674-
((y_bit & sig_off_mask) == 0x0)) {
673+
((y_exp == __iml_fp_config<Ty>::exp_mask) && (y_fra != 0x0))) {
675674
UTy tmp = __iml_fp_config<Ty>::nan_bits;
676675
return __builtin_bit_cast(Ty, tmp);
677676
}
678677

678+
// 0.f / 0.f ----> NAN
679+
if ((y_bit & sig_off_mask) == 0x0) {
680+
if ((x_bit & sig_off_mask) == 0x0) {
681+
UTy tmp = __iml_fp_config<Ty>::nan_bits;
682+
return __builtin_bit_cast(Ty, tmp);
683+
} else {
684+
// return +inf if x_sig and y_sig are same otherwise return -inf
685+
UTy tmp = (z_sig == 0) ? __iml_fp_config<Ty>::pos_inf_bits
686+
: __iml_fp_config<Ty>::neg_inf_bits;
687+
return __builtin_bit_cast(Ty, tmp);
688+
}
689+
}
690+
679691
if ((x_exp == __iml_fp_config<Ty>::exp_mask) && (x_fra == 0x0)) {
680692
if ((y_exp == __iml_fp_config<Ty>::exp_mask) && (y_fra == 0x0)) {
681693
UTy tmp = __iml_fp_config<Ty>::nan_bits;

sycl/test-e2e/DeviceLib/imf/a.out

1.44 MB
Binary file not shown.

sycl/test-e2e/DeviceLib/imf/fp32_rounding_test.cpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,18 +92,24 @@ int main(int, char **) {
9292
}
9393

9494
{
95-
std::initializer_list<float> input_vals1 = {0x1p-1, 0x1.8bd054p+6,
96-
0x1.fcd686p+0, -0x1.7f9abp+3};
97-
std::initializer_list<float> input_vals2 = {-0x1.a8p+2, -0x1.674a3cp+5,
98-
0x1.f3d6aep+10, 0x1.d6bf48p+10};
99-
std::initializer_list<uint32_t> ref_vals_rd = {0xbd9a90e8, 0xc00d030d,
100-
0x3a824df9, 0xbbd09c3a};
101-
std::initializer_list<uint32_t> ref_vals_rn = {0xbd9a90e8, 0xc00d030c,
102-
0x3a824df9, 0xbbd09c39};
103-
std::initializer_list<uint32_t> ref_vals_ru = {0xbd9a90e7, 0xc00d030c,
104-
0x3a824dfa, 0xbbd09c39};
105-
std::initializer_list<uint32_t> ref_vals_rz = {0xbd9a90e7, 0xc00d030c,
106-
0x3a824df9, 0xbbd09c39};
95+
std::initializer_list<float> input_vals1 = {
96+
0x1p-1, 0x1.8bd054p+6, 0x1.fcd686p+0, -0x1.7f9abp+3,
97+
0x1p+0, -0x1p+0, 0x0p+0};
98+
std::initializer_list<float> input_vals2 = {
99+
-0x1.a8p+2, -0x1.674a3cp+5, 0x1.f3d6aep+10, 0x1.d6bf48p+10,
100+
0x0p+0, 0x0p+0, 0x0p+0};
101+
std::initializer_list<uint32_t> ref_vals_rd = {
102+
0xbd9a90e8, 0xc00d030d, 0x3a824df9, 0xbbd09c3a,
103+
0x7F800000, 0xFF800000, 0x7FC00000};
104+
std::initializer_list<uint32_t> ref_vals_rn = {
105+
0xbd9a90e8, 0xc00d030c, 0x3a824df9, 0xbbd09c39,
106+
0x7F800000, 0xFF800000, 0x7FC00000};
107+
std::initializer_list<uint32_t> ref_vals_ru = {
108+
0xbd9a90e7, 0xc00d030c, 0x3a824dfa, 0xbbd09c39,
109+
0x7F800000, 0xFF800000, 0x7FC00000};
110+
std::initializer_list<uint32_t> ref_vals_rz = {
111+
0xbd9a90e7, 0xc00d030c, 0x3a824df9, 0xbbd09c39,
112+
0x7F800000, 0xFF800000, 0x7FC00000};
107113
test2(device_queue, input_vals1, input_vals2, ref_vals_rd,
108114
F2T(uint32_t, sycl::ext::intel::math::fdiv_rd));
109115
std::cout << "sycl::ext::intel::math::fdiv_rd passes." << std::endl;

sycl/test-e2e/DeviceLib/imf/fp64_rounding_test.cpp

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,24 +115,36 @@ int main(int, char **) {
115115
}
116116

117117
{
118-
std::initializer_list<double> input_vals1 = {
119-
0x1.5ef3da7bf609ap+4, 0x1.fbd37afb0f8edp-1, 0x1.9238e38e38e35p+6,
120-
0x1.7p+3};
121-
std::initializer_list<double> input_vals2 = {
122-
-0x1.bc7db6de6d33fp+9, 0x1.2f638fa4e71a6p+10, 0x1.08e38e38e38e3p+4,
123-
-0x1.94p+3};
118+
std::initializer_list<double> input_vals1 = {0x1.5ef3da7bf609ap+4,
119+
0x1.fbd37afb0f8edp-1,
120+
0x1.9238e38e38e35p+6,
121+
0x1.7p+3,
122+
0x1p+0,
123+
-0x1p+0,
124+
0x0p+0};
125+
std::initializer_list<double> input_vals2 = {-0x1.bc7db6de6d33fp+9,
126+
0x1.2f638fa4e71a6p+10,
127+
0x1.08e38e38e38e3p+4,
128+
-0x1.94p+3,
129+
0x0p+0,
130+
0x0p+0,
131+
0x0p+0};
124132
std::initializer_list<uint64_t> ref_vals_rd = {
125133
0xbf994414312c26ab, 0x3f4ac811fc63acd9, 0x40184b98e9aa180a,
126-
0xbfed260511be1959};
134+
0xbfed260511be1959, 0x7FF0000000000000, 0xFFF0000000000000,
135+
0x7FF8000000000000};
127136
std::initializer_list<uint64_t> ref_vals_rn = {
128137
0xbf994414312c26ab, 0x3f4ac811fc63acd9, 0x40184b98e9aa180b,
129-
0xbfed260511be1959};
138+
0xbfed260511be1959, 0x7FF0000000000000, 0xFFF0000000000000,
139+
0x7FF8000000000000};
130140
std::initializer_list<uint64_t> ref_vals_ru = {
131141
0xbf994414312c26aa, 0x3f4ac811fc63acda, 0x40184b98e9aa180b,
132-
0xbfed260511be1958};
142+
0xbfed260511be1958, 0x7FF0000000000000, 0xFFF0000000000000,
143+
0x7FF8000000000000};
133144
std::initializer_list<uint64_t> ref_vals_rz = {
134145
0xbf994414312c26aa, 0x3f4ac811fc63acd9, 0x40184b98e9aa180a,
135-
0xbfed260511be1958};
146+
0xbfed260511be1958, 0x7FF0000000000000, 0xFFF0000000000000,
147+
0x7FF8000000000000};
136148
test2(device_queue, input_vals1, input_vals2, ref_vals_rd,
137149
F2T(uint64_t, sycl::ext::intel::math::ddiv_rd));
138150
std::cout << "sycl::ext::intel::math::ddiv_rd passes." << std::endl;

sycl/test-e2e/DeviceLib/imf/imf_utils.hpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ typedef _Float16 _iml_half_internal;
1414
typedef uint16_t _iml_half_internal;
1515
#endif
1616

17+
#define EXP_MASK32 0xFF
18+
#define EXP_MASK64 0x7FF
19+
#define FRA_MASK32 0x7FFFFF
20+
#define FRA_MASK64 0xFFFFFFFFFFFFF
21+
1722
template <class Ty> class imf_utils_default_equ {
1823
public:
1924
bool operator()(Ty x, Ty y) {
@@ -24,6 +29,28 @@ template <class Ty> class imf_utils_default_equ {
2429
};
2530
};
2631

32+
template <> class imf_utils_default_equ<uint32_t> {
33+
public:
34+
bool operator()(uint32_t x, uint32_t y) {
35+
bool x_is_nan =
36+
(((x >> 23) & EXP_MASK32) == EXP_MASK32) && ((x & FRA_MASK32) != 0);
37+
bool y_is_nan =
38+
(((y >> 23) & EXP_MASK32) == EXP_MASK32) && ((y & FRA_MASK32) != 0);
39+
return (x_is_nan && y_is_nan) || (x == y);
40+
}
41+
};
42+
43+
template <> class imf_utils_default_equ<uint64_t> {
44+
public:
45+
bool operator()(uint64_t x, uint64_t y) {
46+
bool x_is_nan =
47+
(((x >> 52) & EXP_MASK64) == EXP_MASK64) && ((x & FRA_MASK64) != 0);
48+
bool y_is_nan =
49+
(((y >> 52) & EXP_MASK64) == EXP_MASK64) && ((y & FRA_MASK64) != 0);
50+
return (x_is_nan && y_is_nan) || (x == y);
51+
}
52+
};
53+
2754
// Used to test half precision utils
2855
template <class InputTy, class OutputTy, class FuncTy,
2956
class EquTy = imf_utils_default_equ<OutputTy>>

0 commit comments

Comments
 (0)