|
3 | 3 |
|
4 | 4 | import numpy as np |
5 | 5 | import pytest |
| 6 | +from array_api_compat import array_namespace |
6 | 7 |
|
| 8 | +import array_api_extra as xpx # Let some tests bypass lazy_xp_function |
7 | 9 | from array_api_extra import lazy_apply |
8 | 10 | from array_api_extra._lib import Backend |
9 | 11 | from array_api_extra._lib._testing import xp_assert_equal |
10 | 12 | from array_api_extra._lib._utils._typing import Array |
11 | 13 | from array_api_extra.testing import lazy_xp_function |
12 | 14 |
|
| 15 | +lazy_xp_function( |
| 16 | + lazy_apply, static_argnames=("func", "shape", "dtype", "as_numpy", "xp") |
| 17 | +) |
| 18 | + |
13 | 19 | as_numpy = pytest.mark.parametrize( |
14 | 20 | "as_numpy", |
15 | 21 | [ |
|
26 | 32 |
|
27 | 33 |
|
28 | 34 | @as_numpy |
29 | | -def test_lazy_apply_simple(xp: ModuleType, as_numpy: bool): |
30 | | - pytest.skip("TODO") |
| 35 | +@pytest.mark.parametrize("shape", [(2,), (3, 2)]) |
| 36 | +@pytest.mark.parametrize("dtype", ["int32", "float64"]) |
| 37 | +def test_lazy_apply_simple( |
| 38 | + xp: ModuleType, library: Backend, shape: tuple[int, ...], dtype: str, as_numpy: bool |
| 39 | +): |
| 40 | + def f(x: Array) -> Array: |
| 41 | + xp2 = array_namespace(x) |
| 42 | + if as_numpy or library in (Backend.NUMPY_READONLY, Backend.DASK): |
| 43 | + assert isinstance(x, np.ndarray) |
| 44 | + else: |
| 45 | + assert xp2 is xp |
| 46 | + |
| 47 | + y = xp2.broadcast_to(xp2.astype(x + 1, getattr(xp2, dtype)), shape) |
| 48 | + return xp2.asarray(y, copy=True) # Torch: ensure writeable numpy array |
| 49 | + |
| 50 | + x = xp.asarray([1, 2], dtype=xp.int16) |
| 51 | + expect = xp.broadcast_to(xp.astype(x + 1, getattr(xp, dtype)), shape) |
| 52 | + actual = lazy_apply(f, x, shape=shape, dtype=getattr(xp, dtype), as_numpy=as_numpy) |
| 53 | + xp_assert_equal(actual, expect) |
31 | 54 |
|
32 | 55 |
|
33 | 56 | @as_numpy |
34 | 57 | def test_lazy_apply_broadcast(xp: ModuleType, as_numpy: bool): |
35 | | - pytest.skip("TODO") |
| 58 | + def f(x: Array, y: Array) -> Array: |
| 59 | + return x + y |
| 60 | + |
| 61 | + x = xp.asarray([1, 2], dtype=xp.int16) |
| 62 | + y = xp.asarray([[4], [5], [6]], dtype=xp.int32) |
| 63 | + z = lazy_apply(f, x, y, as_numpy=as_numpy) |
| 64 | + xp_assert_equal(z, x + y) |
36 | 65 |
|
37 | 66 |
|
38 | 67 | @as_numpy |
39 | 68 | def test_lazy_apply_multi_output(xp: ModuleType, as_numpy: bool): |
40 | | - pytest.skip("TODO") |
| 69 | + def f(x: Array) -> tuple[Array, Array]: |
| 70 | + xp2 = array_namespace(x) |
| 71 | + y = x + xp2.asarray(2, dtype=xp2.int8) # Sparse: bad dtype propagation |
| 72 | + z = xp2.broadcast_to(xp2.astype(x + 1, xp2.int16), (3, 2)) |
| 73 | + z = xp2.asarray(z, copy=True) # Torch: ensure writeable numpy array |
| 74 | + return y, z |
| 75 | + |
| 76 | + x = xp.asarray([1, 2], dtype=xp.int8) |
| 77 | + expect = ( |
| 78 | + xp.asarray([3, 4], dtype=xp.int8), |
| 79 | + xp.asarray([[2, 3], [2, 3], [2, 3]], dtype=xp.int16), |
| 80 | + ) |
| 81 | + actual = lazy_apply( |
| 82 | + f, x, shape=((2,), (3, 2)), dtype=(xp.int8, xp.int16), as_numpy=as_numpy |
| 83 | + ) |
| 84 | + assert isinstance(actual, tuple) |
| 85 | + assert len(actual) == 2 |
| 86 | + xp_assert_equal(actual[0], expect[0]) |
| 87 | + xp_assert_equal(actual[1], expect[1]) |
41 | 88 |
|
42 | 89 |
|
43 | 90 | def test_lazy_apply_core_indices(da: ModuleType): |
@@ -96,7 +143,8 @@ def eager( |
96 | 143 | assert isinstance(scalar, int) |
97 | 144 | return x + 1 # type: ignore[operator] |
98 | 145 |
|
99 | | - return lazy_apply( # pyright: ignore[reportCallIssue] |
| 146 | + # Use explicit namespace to bypass monkey-patching by lazy_xp_function |
| 147 | + return xpx.lazy_apply( # pyright: ignore[reportCallIssue] |
100 | 148 | eager, |
101 | 149 | x, |
102 | 150 | # These kwargs can and should be passed through jax.pure_callback |
@@ -136,7 +184,8 @@ def eager(_: Array) -> Array: |
136 | 184 | msg = "Hello World" |
137 | 185 | raise CustomError(msg) |
138 | 186 |
|
139 | | - return lazy_apply(eager, x, shape=x.shape, dtype=x.dtype) |
| 187 | + # Use explicit namespace to bypass monkey-patching by lazy_xp_function |
| 188 | + return xpx.lazy_apply(eager, x, shape=x.shape, dtype=x.dtype) |
140 | 189 |
|
141 | 190 |
|
142 | 191 | # jax.pure_callback does not support raising |
|
0 commit comments