Skip to content

Commit a690214

Browse files
committed
Add dedicated sign handling for float16 dtype
1 parent 80d7649 commit a690214

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

dpnp/backend/kernels/elementwise_functions/spacing.hpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,14 @@ struct SpacingFunctor
5151
return std::numeric_limits<resT>::quiet_NaN();
5252
}
5353

54-
const argT y = sycl::copysign(std::numeric_limits<argT>::infinity(), x);
55-
return sycl::nextafter(x, y) - x;
54+
constexpr argT inf = std::numeric_limits<argT>::infinity();
55+
if constexpr (std::is_same_v<argT, sycl::half>) {
56+
// numpy laways computes spacing towards +inf for float16 dtype
57+
return sycl::nextafter(x, inf) - x;
58+
}
59+
else {
60+
return sycl::nextafter(x, sycl::copysign(inf, x)) - x;
61+
}
5662
}
5763
};
5864
} // namespace dpnp::kernels::spacing

tests/test_mathematical.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1898,11 +1898,13 @@ def test_wrong_tol_type(self, xp, tol_val):
18981898

18991899

19001900
class TestSpacing:
1901+
@pytest.mark.parametrize("sign", [1, -1])
19011902
@pytest.mark.parametrize("dt", get_float_dtypes())
1902-
def test_basic(self, dt):
1903+
def test_basic(self, sign, dt):
19031904
a = numpy.array(
1904-
[1, numpy.nan, numpy.inf, 0.0, 1e10, 1e-5, 1000, 10500], dtype=dt
1905+
[1, numpy.nan, numpy.inf, 1e10, 1e-5, 1000, 10500], dtype=dt
19051906
)
1907+
a *= sign
19061908
ia = dpnp.array(a)
19071909

19081910
result = dpnp.spacing(ia)
@@ -1914,6 +1916,22 @@ def test_basic(self, dt):
19141916
expected = numpy.spacing(-a)
19151917
assert_equal(result, expected)
19161918

1919+
@pytest.mark.parametrize("dt", get_float_dtypes())
1920+
def test_zeros(self, dt):
1921+
a = numpy.array([0.0, -0.0], dtype=dt)
1922+
ia = dpnp.array(a)
1923+
1924+
result = dpnp.spacing(ia)
1925+
expected = numpy.spacing(a)
1926+
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
1927+
assert_equal(result, expected)
1928+
else:
1929+
# numpy.spacing(-0.0) == numpy.spacing(0.0), i.e. NumPy returns
1930+
# positive value, while for any other negative input the result
1931+
# will be negative value (looks as a bug in NumPy)
1932+
expected[1] *= -1
1933+
assert_equal(result, expected)
1934+
19171935
@pytest.mark.parametrize("dt", get_float_dtypes(no_float16=False))
19181936
@pytest.mark.parametrize("val", [1, 1e-5, 1000])
19191937
@pytest.mark.parametrize("xp", [numpy, dpnp])

0 commit comments

Comments
 (0)