Skip to content

Commit 253caeb

Browse files
committed
Add more tests to cover different use cases
1 parent 19a27a3 commit 253caeb

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed

tests/test_mathematical.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1897,6 +1897,93 @@ def test_wrong_tol_type(self, xp, tol_val):
18971897
assert_raises(TypeError, xp.real_if_close, a, tol=tol_val)
18981898

18991899

1900+
class TestSinc:
1901+
@pytest.mark.parametrize(
1902+
"dt", get_all_dtypes(no_none=True, no_bool=True, no_float16=False)
1903+
)
1904+
def test_basic(self, dt):
1905+
a = numpy.linspace(-1, 1, 100, dtype=dt)
1906+
ia = dpnp.array(a)
1907+
1908+
result = dpnp.sinc(ia)
1909+
expected = numpy.sinc(a)
1910+
assert_dtype_allclose(result, expected)
1911+
1912+
def test_bool(self):
1913+
a = numpy.array([True, False, True])
1914+
ia = dpnp.array(a)
1915+
1916+
result = dpnp.sinc(ia)
1917+
expected = numpy.sinc(a)
1918+
# numpy 1.26 promotes result to float64 dtype, but expected float16
1919+
assert_dtype_allclose(
1920+
result,
1921+
expected,
1922+
check_only_type_kind=(
1923+
numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0"
1924+
),
1925+
)
1926+
1927+
@pytest.mark.parametrize("dt", get_all_dtypes(no_none=True, no_bool=True))
1928+
def test_zero(self, dt):
1929+
if (
1930+
dt == numpy.float16
1931+
and numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0"
1932+
):
1933+
pytest.skip("numpy.sinc return NaN")
1934+
1935+
a = numpy.array([0.0], dtype=dt)
1936+
ia = dpnp.array(a)
1937+
1938+
result = dpnp.sinc(ia)
1939+
expected = numpy.sinc(a)
1940+
assert_dtype_allclose(result, expected)
1941+
1942+
@testing.with_requires("numpy>=2.0.0")
1943+
def test_zero_fp16(self):
1944+
a = numpy.array([0.0], dtype=numpy.float16)
1945+
ia = dpnp.array(a)
1946+
1947+
result = dpnp.sinc(ia)
1948+
expected = numpy.sinc(a)
1949+
assert_dtype_allclose(result, expected)
1950+
1951+
@pytest.mark.usefixtures("suppress_invalid_numpy_warnings")
1952+
def test_nan_infs(self):
1953+
a = numpy.array([numpy.inf, -numpy.inf, numpy.nan])
1954+
ia = dpnp.array(a)
1955+
1956+
result = dpnp.sinc(ia)
1957+
expected = numpy.sinc(a)
1958+
assert_equal(result, expected)
1959+
1960+
@pytest.mark.usefixtures("suppress_invalid_numpy_warnings")
1961+
def test_nan_infs_complex(self):
1962+
a = numpy.array(
1963+
[
1964+
numpy.inf,
1965+
-numpy.inf,
1966+
numpy.nan,
1967+
complex(numpy.nan),
1968+
complex(numpy.nan, numpy.nan),
1969+
complex(0, numpy.nan),
1970+
complex(numpy.inf, numpy.nan),
1971+
complex(numpy.nan, numpy.inf),
1972+
complex(-numpy.inf, numpy.nan),
1973+
complex(numpy.nan, -numpy.inf),
1974+
complex(numpy.inf, numpy.inf),
1975+
complex(numpy.inf, -numpy.inf),
1976+
complex(-numpy.inf, numpy.inf),
1977+
complex(-numpy.inf, -numpy.inf),
1978+
]
1979+
)
1980+
ia = dpnp.array(a)
1981+
1982+
result = dpnp.sinc(ia)
1983+
expected = numpy.sinc(a)
1984+
assert_equal(result, expected)
1985+
1986+
19001987
class TestTrapezoid:
19011988
def get_numpy_func(self):
19021989
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":

0 commit comments

Comments
 (0)