Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions scipy/_lib/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,6 @@ def test_array_api_extra_hook(self):
with pytest.raises(TypeError, match=msg):
xpx.atleast_nd("abc", ndim=0)

@skip_xp_backends(
"dask.array",
reason="raw dask.array namespace ignores copy=True in asarray"
)
def test_copy(self, xp):
for _xp in [xp, None]:
x = xp.asarray([1, 2, 3])
Expand Down
7 changes: 4 additions & 3 deletions scipy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,10 @@ def skip_or_xfail_xp_backends(request: pytest.FixtureRequest,
if 'cpu' not in d.device_kind:
skip_or_xfail(reason=reason)
elif xp.__name__ == 'dask.array' and 'dask.array' not in exceptions:
if xp_device(xp.empty(0)) != 'cpu':
skip_or_xfail(reason=reason)

# dask has no device. 'cpu' is a hack introduced by array-api-compat.
# Force to revisit this when in the future
# dask adds proper device support
assert xp_device(xp.empty(0)) == 'cpu'

# Following the approach of NumPy's conftest.py...
# Use a known and persistent tmpdir for hypothesis' caches, which
Expand Down
22 changes: 11 additions & 11 deletions scipy/ndimage/tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def test_correlate22(self, dtype_array, dtype_output, xp):
assert_array_almost_equal(output, expected)

@skip_xp_backends("jax.numpy", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output kw doesn't make sense for dask")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A fully Array API compliant function would be able to suppor the output= kwarg with dask.
The problem is that you're calling np.asarray(x) will always return a buffer that is not shared with the input parameter.

Please change reason to "converts dask output array to numpy"

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the comment.

@pytest.mark.parametrize('dtype_array', types)
@pytest.mark.parametrize('dtype_output', types)
def test_correlate23(self, dtype_array, dtype_output, xp):
Expand All @@ -554,7 +554,7 @@ def test_correlate23(self, dtype_array, dtype_output, xp):
assert_array_almost_equal(output, expected)

@skip_xp_backends("jax.numpy", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output kw doesn't make sense for dask")
@pytest.mark.parametrize('dtype_array', types)
@pytest.mark.parametrize('dtype_output', types)
def test_correlate24(self, dtype_array, dtype_output, xp):
Expand All @@ -575,7 +575,7 @@ def test_correlate24(self, dtype_array, dtype_output, xp):
assert_array_almost_equal(output, tcov)

@skip_xp_backends("jax.numpy", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output kw doesn't make sense for dask")
@pytest.mark.parametrize('dtype_array', types)
@pytest.mark.parametrize('dtype_output', types)
def test_correlate25(self, dtype_array, dtype_output, xp):
Expand Down Expand Up @@ -881,7 +881,7 @@ def test_gauss06(self, xp):
assert_array_almost_equal(output1, output2)

@skip_xp_backends("jax.numpy", reason="output array is read-only.")
@xfail_xp_backends("dask.array", reason="output keyword not handled properly")
@skip_xp_backends("dask.array", reason="output kw doesn't make sense for dask")
def test_gauss_memory_overlap(self, xp):
input = xp.arange(100 * 100, dtype=xp.float32)
input = xp.reshape(input, (100, 100))
Expand Down Expand Up @@ -1228,7 +1228,7 @@ def test_prewitt01(self, dtype, xp):
assert_array_almost_equal(t, output)

@skip_xp_backends("jax.numpy", reason="output array is read-only.")
@xfail_xp_backends("dask.array", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output kw doesn't make sense for dask")
@pytest.mark.parametrize('dtype', types + complex_types)
def test_prewitt02(self, dtype, xp):
if is_torch(xp) and dtype in ("uint16", "uint32", "uint64"):
Expand Down Expand Up @@ -1291,7 +1291,7 @@ def test_sobel01(self, dtype, xp):
assert_array_almost_equal(t, output)

@skip_xp_backends("jax.numpy", reason="output array is read-only.",)
@skip_xp_backends("dask.array", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output kw doesn't make sense for dask")
@pytest.mark.parametrize('dtype', types + complex_types)
def test_sobel02(self, dtype, xp):
if is_torch(xp) and dtype in ("uint16", "uint32", "uint64"):
Expand Down Expand Up @@ -1352,7 +1352,7 @@ def test_laplace01(self, dtype, xp):
assert_array_almost_equal(tmp1 + tmp2, output)

@skip_xp_backends("jax.numpy", reason="output array is read-only",)
@skip_xp_backends("dask.array", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output kw doesn't make sense for dask")
@pytest.mark.parametrize('dtype',
["int32", "float32", "float64",
"complex64", "complex128"])
Expand Down Expand Up @@ -1383,7 +1383,7 @@ def test_gaussian_laplace01(self, dtype, xp):
assert_array_almost_equal(tmp1 + tmp2, output)

@skip_xp_backends("jax.numpy", reason="output array is read-only")
@skip_xp_backends("dask.array", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output kw doesn't make sense for dask")
@pytest.mark.parametrize('dtype',
["int32", "float32", "float64",
"complex64", "complex128"])
Expand All @@ -1400,7 +1400,7 @@ def test_gaussian_laplace02(self, dtype, xp):
assert_array_almost_equal(tmp1 + tmp2, output)

@skip_xp_backends("jax.numpy", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output kw doesn't make sense for dask")
@pytest.mark.parametrize('dtype', types + complex_types)
def test_generic_laplace01(self, dtype, xp):
if is_torch(xp) and dtype in ("uint16", "uint32", "uint64"):
Expand All @@ -1426,7 +1426,7 @@ def derivative2(input, axis, output, mode, cval, a, b):
assert_array_almost_equal(tmp, output)

@skip_xp_backends("jax.numpy", reason="output array is read-only")
@skip_xp_backends("dask.array", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output kw doesn't make sense for dask")
@pytest.mark.parametrize('dtype',
["int32", "float32", "float64",
"complex64", "complex128"])
Expand All @@ -1447,7 +1447,7 @@ def test_gaussian_gradient_magnitude01(self, dtype, xp):
xp_assert_close(output, expected, rtol=1e-6, atol=1e-6)

@skip_xp_backends("jax.numpy", reason="output array is read-only")
@skip_xp_backends("dask.array", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output kw doesn't make sense for dask")
@pytest.mark.parametrize('dtype',
["int32", "float32", "float64",
"complex64", "complex128"])
Expand Down
12 changes: 4 additions & 8 deletions scipy/ndimage/tests/test_morphology.py
Original file line number Diff line number Diff line change
Expand Up @@ -2510,7 +2510,7 @@ def test_morphological_laplace02(self, xp):
assert_array_almost_equal(output, expected)

@skip_xp_backends("jax.numpy", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output kw doesn't make sense for dask")
def test_white_tophat01(self, xp):
array = xp.asarray([[3, 2, 5, 1, 4],
[7, 6, 9, 3, 5],
Expand All @@ -2520,8 +2520,6 @@ def test_white_tophat01(self, xp):
tmp = ndimage.grey_opening(array, footprint=footprint,
structure=structure)
expected = array - tmp
# array created by xp.zeros is non-writeable for dask
# and jax
output = xp.zeros(array.shape, dtype=array.dtype)
ndimage.white_tophat(array, footprint=footprint,
structure=structure, output=output)
Expand Down Expand Up @@ -2566,7 +2564,7 @@ def test_white_tophat03(self, xp):
xp_assert_equal(output, expected)

@skip_xp_backends("jax.numpy", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output kw doesn't make sense for dask")
def test_white_tophat04(self, xp):
array = np.eye(5, dtype=bool)
structure = np.ones((3, 3), dtype=bool)
Expand All @@ -2575,13 +2573,11 @@ def test_white_tophat04(self, xp):
structure = xp.asarray(structure)

# Check that type mismatch is properly handled
# This output array is read-only for dask and jax
# TODO: investigate why for dask?
output = xp.empty_like(array, dtype=xp.float64)
ndimage.white_tophat(array, structure=structure, output=output)

@skip_xp_backends("jax.numpy", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output kw doesn't make sense for dask")
def test_black_tophat01(self, xp):
array = xp.asarray([[3, 2, 5, 1, 4],
[7, 6, 9, 3, 5],
Expand Down Expand Up @@ -2636,7 +2632,7 @@ def test_black_tophat03(self, xp):
xp_assert_equal(output, expected)

@skip_xp_backends("jax.numpy", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output array is read-only.")
@skip_xp_backends("dask.array", reason="output kw doesn't make sense for dask")
def test_black_tophat04(self, xp):
array = xp.asarray(np.eye(5, dtype=bool))
structure = xp.asarray(np.ones((3, 3), dtype=bool))
Expand Down
8 changes: 4 additions & 4 deletions scipy/signal/_filter_design.py
Original file line number Diff line number Diff line change
Expand Up @@ -1779,6 +1779,9 @@ def normalize(b, a):
"""
num, den = b, a

# cast to numpy by hand to avoid libraries like dask
# trying to dispatch this function via NEP 18
den = np.asarray(den)
den = np.atleast_1d(den)
num = np.atleast_2d(_align_nums(num))

Expand All @@ -1791,10 +1794,7 @@ def normalize(b, a):
raise ValueError("Denominator must have at least on nonzero element.")

# Trim leading zeros in denominator, leave at least one.

# cast to numpy by hand to avoid libraries like dask
# trying to dispatch this function via NEP 18
den = np.trim_zeros(np.asarray(den), 'f')
den = np.trim_zeros(den, 'f')

# Normalize transfer function
num, den = num / den[0], den / den[0]
Expand Down
4 changes: 2 additions & 2 deletions scipy/signal/_signaltools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4089,8 +4089,8 @@ def detrend(data: np.ndarray, axis: int = -1,
N = dshape[axis]
# Manually cast to numpy to prevent
# NEP18 dispatching for libraries like dask
bp = np.asarray(np.concatenate(np.atleast_1d(0, bp, N)))
bp = np.sort(np.unique(bp))
bp = np.asarray(bp)
bp = np.sort(np.unique(np.concatenate(np.atleast_1d(0, bp, N))))
if np.any(bp > N):
raise ValueError("Breakpoints must be less than length "
"of data along given axis.")
Expand Down
4 changes: 1 addition & 3 deletions scipy/signal/tests/test_signaltools.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,6 @@ def test_dtype_deprecation(self, xp):
convolve(a, b)



@skip_xp_backends(cpu_only=True, exceptions=['cupy'])
class TestConvolve2d:

Expand Down Expand Up @@ -512,7 +511,6 @@ def test_large_array(self, xp):
assert fails[0].size == 0



@skip_xp_backends(cpu_only=True, exceptions=['cupy'])
class TestFFTConvolve:

Expand Down Expand Up @@ -975,7 +973,7 @@ def gen_oa_shapes_eq(sizes):


@skip_xp_backends("jax.numpy", reason="fails all around")
@skip_xp_backends("dask.array", reason="wrong answer")
@xfail_xp_backends("dask.array", reason="wrong answer")
class TestOAConvolve:
@pytest.mark.slow()
@pytest.mark.parametrize('shape_a_0, shape_b_0',
Expand Down