Skip to content

Commit a6bb413

Browse files
committed
WIP tests
1 parent ccd4c19 commit a6bb413

File tree

1 file changed

+55
-6
lines changed

1 file changed

+55
-6
lines changed

tests/test_lazy.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,19 @@
33

44
import numpy as np
55
import pytest
6+
from array_api_compat import array_namespace
67

8+
import array_api_extra as xpx # Let some tests bypass lazy_xp_function
79
from array_api_extra import lazy_apply
810
from array_api_extra._lib import Backend
911
from array_api_extra._lib._testing import xp_assert_equal
1012
from array_api_extra._lib._utils._typing import Array
1113
from array_api_extra.testing import lazy_xp_function
1214

15+
lazy_xp_function(
16+
lazy_apply, static_argnames=("func", "shape", "dtype", "as_numpy", "xp")
17+
)
18+
1319
as_numpy = pytest.mark.parametrize(
1420
"as_numpy",
1521
[
@@ -26,18 +32,59 @@
2632

2733

2834
@as_numpy
29-
def test_lazy_apply_simple(xp: ModuleType, as_numpy: bool):
30-
pytest.skip("TODO")
35+
@pytest.mark.parametrize("shape", [(2,), (3, 2)])
36+
@pytest.mark.parametrize("dtype", ["int32", "float64"])
37+
def test_lazy_apply_simple(
38+
xp: ModuleType, library: Backend, shape: tuple[int, ...], dtype: str, as_numpy: bool
39+
):
40+
def f(x: Array) -> Array:
41+
xp2 = array_namespace(x)
42+
if as_numpy or library in (Backend.NUMPY_READONLY, Backend.DASK):
43+
assert isinstance(x, np.ndarray)
44+
else:
45+
assert xp2 is xp
46+
47+
y = xp2.broadcast_to(xp2.astype(x + 1, getattr(xp2, dtype)), shape)
48+
return xp2.asarray(y, copy=True) # Torch: ensure writeable numpy array
49+
50+
x = xp.asarray([1, 2], dtype=xp.int16)
51+
expect = xp.broadcast_to(xp.astype(x + 1, getattr(xp, dtype)), shape)
52+
actual = lazy_apply(f, x, shape=shape, dtype=getattr(xp, dtype), as_numpy=as_numpy)
53+
xp_assert_equal(actual, expect)
3154

3255

3356
@as_numpy
3457
def test_lazy_apply_broadcast(xp: ModuleType, as_numpy: bool):
35-
pytest.skip("TODO")
58+
def f(x: Array, y: Array) -> Array:
59+
return x + y
60+
61+
x = xp.asarray([1, 2], dtype=xp.int16)
62+
y = xp.asarray([[4], [5], [6]], dtype=xp.int32)
63+
z = lazy_apply(f, x, y, as_numpy=as_numpy)
64+
xp_assert_equal(z, x + y)
3665

3766

3867
@as_numpy
3968
def test_lazy_apply_multi_output(xp: ModuleType, as_numpy: bool):
40-
pytest.skip("TODO")
69+
def f(x: Array) -> tuple[Array, Array]:
70+
xp2 = array_namespace(x)
71+
y = x + xp2.asarray(2, dtype=xp2.int8) # Sparse: bad dtype propagation
72+
z = xp2.broadcast_to(xp2.astype(x + 1, xp2.int16), (3, 2))
73+
z = xp2.asarray(z, copy=True) # Torch: ensure writeable numpy array
74+
return y, z
75+
76+
x = xp.asarray([1, 2], dtype=xp.int8)
77+
expect = (
78+
xp.asarray([3, 4], dtype=xp.int8),
79+
xp.asarray([[2, 3], [2, 3], [2, 3]], dtype=xp.int16),
80+
)
81+
actual = lazy_apply(
82+
f, x, shape=((2,), (3, 2)), dtype=(xp.int8, xp.int16), as_numpy=as_numpy
83+
)
84+
assert isinstance(actual, tuple)
85+
assert len(actual) == 2
86+
xp_assert_equal(actual[0], expect[0])
87+
xp_assert_equal(actual[1], expect[1])
4188

4289

4390
def test_lazy_apply_core_indices(da: ModuleType):
@@ -96,7 +143,8 @@ def eager(
96143
assert isinstance(scalar, int)
97144
return x + 1 # type: ignore[operator]
98145

99-
return lazy_apply( # pyright: ignore[reportCallIssue]
146+
# Use explicit namespace to bypass monkey-patching by lazy_xp_function
147+
return xpx.lazy_apply( # pyright: ignore[reportCallIssue]
100148
eager,
101149
x,
102150
# These kwargs can and should be passed through jax.pure_callback
@@ -136,7 +184,8 @@ def eager(_: Array) -> Array:
136184
msg = "Hello World"
137185
raise CustomError(msg)
138186

139-
return lazy_apply(eager, x, shape=x.shape, dtype=x.dtype)
187+
# Use explicit namespace to bypass monkey-patching by lazy_xp_function
188+
return xpx.lazy_apply(eager, x, shape=x.shape, dtype=x.dtype)
140189

141190

142191
# jax.pure_callback does not support raising

0 commit comments

Comments
 (0)