Skip to content

Commit 7baaa4a

Browse files
committed
TST: sinc: add tests
1 parent 28bff59 commit 7baaa4a

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

src/array_api_extra/_funcs.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,9 +367,9 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array:
367367
368368
Parameters
369369
----------
370-
x : array
370+
x : array of floats
371371
Array (possibly multi-dimensional) of values for which to calculate
372-
``sinc(x)``.
372+
``sinc(x)``. Should have a floating point dtype.
373373
374374
Returns
375375
-------
@@ -423,5 +423,8 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array:
423423
-3.89817183e-17], dtype=array_api_strict.float64)
424424
425425
"""
426+
if not xp.isdtype(x.dtype, "real floating"):
427+
err_msg = "`x` must have a real floating data type."
428+
raise ValueError(err_msg)
426429
y = xp.pi * xp.where(x == 0, xp.asarray(1.0e-20), x)
427430
return xp.sin(y) / y

tests/test_funcs.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytest
1010
from numpy.testing import assert_allclose, assert_array_equal, assert_equal
1111

12-
from array_api_extra import atleast_nd, cov, expand_dims, kron
12+
from array_api_extra import atleast_nd, cov, expand_dims, kron, sinc
1313

1414
if TYPE_CHECKING:
1515
Array = Any # To be changed to a Protocol later (see array-api#589)
@@ -224,3 +224,16 @@ def test_positive_negative_repeated(self):
224224
a = xp.empty((2, 3, 4, 5))
225225
with pytest.raises(ValueError, match="Duplicate dimensions"):
226226
expand_dims(a, axis=(3, -3), xp=xp)
227+
228+
229+
class TestSinc:
230+
def test_simple(self):
231+
assert_array_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))
232+
w = sinc(xp.linspace(-1, 1, 100), xp=xp)
233+
# check symmetry
234+
assert_allclose(w, xp.flip(w, axis=0))
235+
236+
@pytest.mark.parametrize("x", [0, 1 + 3j])
237+
def test_dtype(self, x):
238+
with pytest.raises(ValueError, match="real floating data type"):
239+
sinc(xp.asarray(x), xp=xp)

0 commit comments

Comments
 (0)