|
3 | 3 |
|
4 | 4 | import numpy as np |
5 | 5 | import pytest |
6 | | -from numpy import min as numpy_min |
7 | 6 |
|
8 | 7 | from array_api_extra._lib import Backend |
9 | 8 | from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal |
@@ -205,21 +204,30 @@ def test_lazy_xp_function_static_params(xp: ModuleType, func: Callable[..., Arra |
205 | 204 | xp_assert_equal(func(x, n=1, flag=True), xp.asarray([2.0, 4.0])) |
206 | 205 |
|
207 | 206 |
|
208 | | -lazy_xp_function(numpy_min, static_argnames="axis") |
| 207 | +try: |
| 208 | + # Test an arbitrary Cython ufunc (@cython.vectorize). |
| 209 | + # When SCIPY_ARRAY_API is not set, this is the same as |
| 210 | + # scipy.special.erf. |
| 211 | + from scipy.special._ufuncs import erf # type: ignore[import-not-found] |
209 | 212 |
|
| 213 | + lazy_xp_function(erf) # pyright: ignore[reportUnknownArgumentType] |
| 214 | +except ImportError: |
| 215 | + erf = None |
210 | 216 |
|
211 | | -def test_lazy_xp_function_ufunc(xp: ModuleType, library: Backend): |
212 | | - x = xp.asarray([[1, 4], [3, 2]]) |
213 | | - if library in (Backend.ARRAY_API_STRICT, Backend.TORCH, Backend.JAX): |
214 | | - # array-api-strict, torch and jax don't define __array_ufunc__ |
215 | | - # numpy ufuncs can't auto-convert to numpy from torch |
| 217 | + |
| 218 | +@pytest.mark.filterwarnings("ignore:__array_wrap__:DeprecationWarning") # torch |
| 219 | +def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend): |
| 220 | + pytest.importorskip("scipy") |
| 221 | + assert erf is not None |
| 222 | + x = xp.asarray([6.0, 7.0]) |
| 223 | + if library in (Backend.ARRAY_API_STRICT, Backend.JAX): |
216 | 224 | # array-api-strict arrays are auto-converted to numpy |
217 | 225 | # eager jax arrays are auto-converted to numpy in eager jax |
218 | 226 | # and fail in jax.jit (which lazy_xp_function tests here) |
219 | 227 | with pytest.raises((TypeError, AssertionError)): |
220 | | - xp_assert_equal(numpy_min(x, axis=0), xp.asarray([1, 2])) |
| 228 | + xp_assert_equal(erf(x), xp.asarray([1.0, 1.0])) |
221 | 229 | else: |
222 | 230 | # cupy, dask and sparse define __array_ufunc__ and dispatch accordingly |
223 | 231 | # note that when sparse reduces to scalar it returns a np.generic, which |
224 | 232 | # would make xp_assert_equal fail. |
225 | | - xp_assert_equal(numpy_min(x, axis=0), xp.asarray([1, 2])) |
| 233 | + xp_assert_equal(erf(x), xp.asarray([1.0, 1.0])) |
0 commit comments