Skip to content

Commit f739f65

Browse files
committed
fft/linalg/special passing tests
1 parent 14deab1 commit f739f65

File tree

6 files changed

+32
-13
lines changed

6 files changed

+32
-13
lines changed

scipy/conftest.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,11 @@ def num_parallel_threads():
179179
pass
180180

181181
try:
182-
import dask.array # type: ignore[import-not-found]
183-
xp_available_backends.update({'dask.array': dask.array})
182+
# Note: dask.array main namespace is not array API compatible
183+
# (to address this, we will fix tests that use the broken dask behavior to
184+
# use the array-api-compat wrapped version instead)
185+
import dask.array as da
186+
xp_available_backends.update({'dask.array': da})
184187
except ImportError:
185188
pass
186189

scipy/fft/tests/test_basic.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,12 @@ def test_dtypes_real(self, dtype, xp):
333333

334334
@pytest.mark.parametrize("dtype", ["complex64", "complex128"])
335335
def test_dtypes_complex(self, dtype, xp):
336+
# Trick to get the array-api-compat namespace for dask
337+
# (otherwise the "naked" dask.array asarray does not respect
338+
# the input dtype)
339+
xp_test = array_namespace(xp.asarray(1))
336340
rng = np.random.default_rng(1234)
337-
x = xp.asarray(rng.random(30), dtype=getattr(xp, dtype))
341+
x = xp.asarray(rng.random(30), dtype=getattr(xp_test, dtype))
338342

339343
res_fft = fft.ifft(fft.fft(x))
340344
# Check both numerical results and exact dtype matches

scipy/fft/tests/test_real_transforms.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import scipy.fft as fft
88
from scipy import fftpack
99
from scipy.conftest import array_api_compatible
10-
from scipy._lib._array_api import xp_copy, xp_assert_close
10+
from scipy._lib._array_api import xp_copy, xp_assert_close, array_namespace
1111

1212
pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
1313
skip_xp_backends = pytest.mark.skip_xp_backends
@@ -199,6 +199,10 @@ def test_orthogonalize_noop(func, type, norm, xp):
199199
def test_orthogonalize_dct1(norm, xp):
200200
x = xp.asarray(np.random.rand(100))
201201

202+
# use array-api-compat namespace for dask
203+
# since dask asarray never makes a copy
204+
# which makes xp_copy silently a no-op
205+
xp = array_namespace(x)
202206
x2 = xp_copy(x, xp=xp)
203207
x2[0] *= SQRT_2
204208
x2[-1] *= SQRT_2
@@ -232,6 +236,10 @@ def test_orthogonalize_dcst2(func, norm, xp):
232236
@pytest.mark.parametrize("func", [dct, dst])
233237
def test_orthogonalize_dcst3(func, norm, xp):
234238
x = xp.asarray(np.random.rand(100))
239+
# use array-api-compat namespace for dask
240+
# since dask asarray never makes a copy
241+
# which makes xp_copy silently a no-op
242+
xp = array_namespace(x)
235243
x2 = xp_copy(x, xp=xp)
236244
x2[0 if func == dct else -1] *= SQRT_2
237245

scipy/special/_logsumexp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,9 @@ def _wrap_radians(x, xp=None):
147147
out = -((-x + math.pi) % (2 * math.pi) - math.pi)
148148
# preserve relative precision
149149
no_wrap = xp.abs(x) < xp.pi
150-
out[no_wrap] = x[no_wrap]
150+
# TODO: i think this is correct but double check
151+
# out[no_wrap] = x[no_wrap]
152+
out = xp.where(no_wrap, x, out)
151153
return out
152154

153155

scipy/special/tests/test_logsumexp.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@
1818

1919

2020
@array_api_compatible
21-
@pytest.mark.usefixtures("skip_xp_backends")
22-
@pytest.mark.skip_xp_backends('jax.numpy',
23-
reason="JAX arrays do not support item assignment")
2421
def test_wrap_radians(xp):
2522
x = xp.asarray([-math.pi-1, -math.pi, -1, -1e-300,
2623
0, 1e-300, 1, math.pi, math.pi+1])
@@ -35,6 +32,10 @@ def test_wrap_radians(xp):
3532
@pytest.mark.skip_xp_backends('jax.numpy',
3633
reason="JAX arrays do not support item assignment")
3734
class TestLogSumExp:
35+
# numpy warning filters don't work for dask
36+
# (also we should not expect the numpy warning filter to work for any Array API
37+
# library)
38+
@pytest.mark.filterwarnings("ignore:divide by zero encountered in log")
3839
def test_logsumexp(self, xp):
3940
# Test with zero-size array
4041
a = xp.asarray([])
@@ -69,11 +70,8 @@ def test_logsumexp(self, xp):
6970
nan = xp.asarray([xp.nan])
7071
xp_assert_equal(logsumexp(inf), inf[0])
7172
xp_assert_equal(logsumexp(-inf), -inf[0])
72-
# catch warnings here for dasks state there's no way to suppress
73-
# warnings just for dask
74-
# https://github.com/dask/dask/issues/3245
75-
with np.errstate(divide='ignore', invalid='ignore'):
76-
xp_assert_equal(logsumexp(nan), nan[0])
73+
74+
xp_assert_equal(logsumexp(nan), nan[0])
7775
xp_assert_equal(logsumexp(xp.asarray([-xp.inf, -xp.inf])), -inf[0])
7876

7977
# Handling an array with different magnitudes on the axes
@@ -119,6 +117,7 @@ def test_logsumexp_sign(self, xp):
119117
xp_assert_close(r, xp.asarray(1.))
120118
xp_assert_equal(s, xp.asarray(-1.))
121119

120+
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
122121
def test_logsumexp_sign_zero(self, xp):
123122
a = xp.asarray([1, 1])
124123
b = xp.asarray([1, -1])
@@ -223,6 +222,7 @@ def test_gh18295(self, xp):
223222
ref = xp.logaddexp(a[0], a[1])
224223
xp_assert_close(res, ref)
225224

225+
@pytest.mark.filterwarnings("ignore::FutureWarning:dask")
226226
@pytest.mark.parametrize('dtype', ['complex64', 'complex128'])
227227
def test_gh21610(self, xp, dtype):
228228
# gh-21610 noted that `logsumexp` could return imaginary components

scipy/special/tests/test_support_alternative_backends.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ def test_rel_entr_generic(dtype):
5252
# @pytest.mark.usefixtures("skip_xp_backends")
5353
# `reversed` is for developer convenience: test new function first = less waiting
5454
@pytest.mark.parametrize('f_name_n_args', reversed(array_special_func_map.items()))
55+
# numpy warning filter doesn't work for dask
56+
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
5557
@pytest.mark.parametrize('dtype', ['float32', 'float64'])
5658
@pytest.mark.parametrize('shapes', [[(0,)]*4, [tuple()]*4, [(10,)]*4,
5759
[(10,), (11, 1), (12, 1, 1), (13, 1, 1, 1)]])

0 commit comments

Comments
 (0)