Skip to content

Commit a56b885

Browse files
MAINT: _lib: co-vendor array-api-extra and array-api-compat (scipy#22062)
* DNM array_api_extra depends on array_api_compat * Fix build * nits * bump array-api-extra * Switch to array-api-compat main * DNM switch array-api-extra to crusaderky/init * copy __init__.py * Repoint array-api-extra to main * set array-api-extra and array-api-compat to latest release --------- Co-authored-by: Lucas Colley <[email protected]>
1 parent 33262e3 commit a56b885

File tree

5 files changed

+32
-5
lines changed

5 files changed

+32
-5
lines changed

.github/workflows/array_api.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ env:
3232
-t scipy.stats
3333
-t scipy.ndimage
3434
-t scipy.integrate.tests.test_quadrature
35-
-t scipy/signal/tests/test_signaltools.py
35+
-t scipy.signal.tests.test_signaltools
3636
3737
concurrency:
3838
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# DO NOT RENAME THIS FILE
2+
# This is a hook for array_api_extra/src/array_api_extra/_lib/_compat.py
3+
# to override functions of array_api_compat.
4+
5+
from .array_api_compat import * # noqa: F403
6+
from ._array_api import array_namespace as scipy_array_namespace
7+
8+
# overrides array_api_compat.array_namespace inside array-api-extra
9+
array_namespace = scipy_array_namespace # type: ignore[assignment]

scipy/_lib/meson.build

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ py3.extension_module('messagestream',
113113
python_sources = [
114114
'__init__.py',
115115
'_array_api.py',
116+
'_array_api_compat_vendor.py',
116117
'_array_api_no_0d.py',
117118
'_bunch.py',
118119
'_ccallback.py',
@@ -216,11 +217,21 @@ py3.install_sources(
216217
# `array_api_extra` install to simplify import path;
217218
# should be updated whenever new files are added to `array_api_extra`
218219

220+
py3.install_sources(
221+
[
222+
'array_api_extra/src/array_api_extra/_lib/__init__.py',
223+
'array_api_extra/src/array_api_extra/_lib/_compat.py',
224+
'array_api_extra/src/array_api_extra/_lib/_compat.pyi',
225+
'array_api_extra/src/array_api_extra/_lib/_utils.py',
226+
'array_api_extra/src/array_api_extra/_lib/_typing.py',
227+
],
228+
subdir: 'scipy/_lib/array_api_extra/_lib',
229+
)
230+
219231
py3.install_sources(
220232
[
221233
'array_api_extra/src/array_api_extra/__init__.py',
222234
'array_api_extra/src/array_api_extra/_funcs.py',
223-
'array_api_extra/src/array_api_extra/_typing.py',
224235
],
225236
subdir: 'scipy/_lib/array_api_extra',
226237
)

scipy/_lib/tests/test_array_api.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
_GLOBAL_CONFIG, array_namespace, _asarray, xp_copy, xp_assert_equal, is_numpy,
77
np_compat,
88
)
9+
from scipy._lib import array_api_extra as xpx
910
from scipy._lib._array_api_no_0d import xp_assert_equal as xp_assert_equal_no_0d
1011

1112
skip_xp_backends = pytest.mark.skip_xp_backends
@@ -54,6 +55,14 @@ def test_array_likes(self):
5455
array_namespace(1, 2, 3)
5556
array_namespace(1)
5657

58+
def test_array_api_extra_hook(self):
59+
"""Test that the `array_namespace` function used by
60+
array-api-extra has been overridden by scipy
61+
"""
62+
msg = "only boolean and numerical dtypes are supported"
63+
with pytest.raises(TypeError, match=msg):
64+
xpx.atleast_nd("abc", ndim=0)
65+
5766
@skip_xp_backends('jax.numpy',
5867
reason="JAX arrays do not support item assignment")
5968
@pytest.mark.usefixtures("skip_xp_backends")
@@ -112,7 +121,6 @@ def test_strict_checks(self, xp, dtype, shape):
112121
with pytest.raises(AssertionError, match="Array-ness does not match."):
113122
xp_assert_equal(x, y, **options)
114123

115-
116124
@array_api_compatible
117125
def test_check_scalar(self, xp):
118126
if not is_numpy(xp):
@@ -147,7 +155,6 @@ def test_check_scalar(self, xp):
147155
# as an alternative to `check_0d=False`, explicitly expect scalar
148156
xp_assert_equal(xp.float64(0), xp.asarray(0.)[()])
149157

150-
151158
@array_api_compatible
152159
def test_check_scalar_no_0d(self, xp):
153160
if not is_numpy(xp):

0 commit comments

Comments
 (0)