Skip to content

Commit ec480f7

Browse files
committed
tests
1 parent 92e703b commit ec480f7

File tree

2 files changed

+114
-17
lines changed

2 files changed

+114
-17
lines changed

src/array_api_extra/_lib/_lazy.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
195195
dask.array.map_overlap
196196
dask.array.blockwise
197197
"""
198+
if not args:
199+
msg = "Must have at least one argument array"
200+
raise ValueError(msg)
198201
if xp is None:
199202
xp = array_namespace(*args)
200203

@@ -204,9 +207,13 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
204207
multi_output = False
205208

206209
if shape is None:
207-
shapes = [xp.broadcast_shapes(*(arg.shape for arg in args))]
208-
elif isinstance(shape, tuple) and all(isinstance(s, int | None) for s in shape):
209-
shapes = [shape] # pyright: ignore[reportAssignmentType]
210+
import numpy as np # DNM
211+
212+
shapes = [np.broadcast_shapes(*(arg.shape for arg in args))]
213+
elif all(isinstance(s, int | None) for s in shape):
214+
# Do not test for shape to be a tuple
215+
# https://github.com/data-apis/array-api/issues/891#issuecomment-2637430522
216+
shapes = [cast(tuple[int | None, ...], shape)]
210217
else:
211218
shapes = list(shape) # type: ignore[arg-type] # pyright: ignore[reportAssignmentType]
212219
multi_output = True

tests/test_lazy.py

Lines changed: 104 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33

44
import numpy as np
55
import pytest
6-
from array_api_compat import array_namespace
76

87
import array_api_extra as xpx # Let some tests bypass lazy_xp_function
98
from array_api_extra import lazy_apply
109
from array_api_extra._lib import Backend
1110
from array_api_extra._lib._testing import xp_assert_equal
12-
from array_api_extra._lib._utils._typing import Array
11+
from array_api_extra._lib._utils import _compat
12+
from array_api_extra._lib._utils._compat import array_namespace
13+
from array_api_extra._lib._utils._typing import Array, Device
1314
from array_api_extra.testing import lazy_xp_function
1415

1516
lazy_xp_function(
@@ -55,6 +56,8 @@ def f(x: Array) -> Array:
5556

5657
@as_numpy
5758
def test_lazy_apply_broadcast(xp: ModuleType, as_numpy: bool):
59+
"""Test that default shape and dtype are broadcasted from the inputs."""
60+
5861
def f(x: Array, y: Array) -> Array:
5962
return x + y
6063

@@ -88,31 +91,117 @@ def f(x: Array) -> tuple[Array, Array]:
8891

8992

9093
def test_lazy_apply_core_indices(da: ModuleType):
91-
"""Test that a func that performs reductions along axes does so
94+
"""
95+
Test that a function that performs reductions along axes does so
9296
globally and not locally to each Dask chunk.
9397
"""
94-
pytest.skip("TODO")
98+
99+
def f(x: Array) -> Array:
100+
return x.sum(axis=0) + x
101+
102+
x_np = np.arange(15).reshape(5, 3)
103+
expect = da.asarray(f(x_np))
104+
x_da = da.asarray(x_np).rechunk(3)
105+
106+
# A naive map_blocks fails because it applies f to each chunk separately,
107+
# but f needs to reduce along axis 0 which is broken into multiple chunks.
108+
# axis 0 is a "core axis" or "core index" (from xarray.apply_ufunc's
109+
# "core dimension").
110+
with pytest.raises(AssertionError):
111+
xp_assert_equal(da.map_blocks(f, x_da), expect)
112+
113+
xp_assert_equal(lazy_apply(f, x_da), expect)
95114

96115

97116
def test_lazy_apply_dont_run_on_meta(da: ModuleType):
98117
"""Test that Dask won't try running func on the meta array,
99118
as it may have minimum size requirements.
100119
"""
101-
pytest.skip("TODO")
102120

121+
def f(x: Array) -> Array:
122+
assert x.size
123+
return x + 1
103124

104-
def test_lazy_apply_none_shape(da: ModuleType):
105-
pytest.skip("TODO")
125+
x = da.arange(10)
126+
assert not x._meta.size
127+
y = lazy_apply(f, x)
128+
xp_assert_equal(y, x + 1)
106129

107130

108-
@as_numpy
109-
def test_lazy_apply_device(xp: ModuleType, as_numpy: bool):
110-
pytest.skip("TODO")
131+
@pytest.mark.xfail_xp_backend(Backend.JAX, reason="unknown shape")
132+
def test_lazy_apply_none_shape_in_args(xp: ModuleType, library: Backend):
133+
x = xp.asarray([1, 1, 2, 2, 2])
111134

135+
xp2 = np if library is Backend.DASK else xp
112136

113-
@as_numpy
114-
def test_lazy_apply_no_args(xp: ModuleType, as_numpy: bool):
115-
pytest.skip("TODO")
137+
# Single output
138+
values = lazy_apply(xp2.unique_values, x, shape=(None,))
139+
xp_assert_equal(values, xp.asarray([1, 2]))
140+
141+
# Multi output
142+
int_type = xp.asarray(0).dtype
143+
values, counts = lazy_apply(
144+
xp2.unique_counts,
145+
x,
146+
shape=((None,), (None,)),
147+
dtype=(x.dtype, int_type),
148+
)
149+
xp_assert_equal(values, xp.asarray([1, 2]))
150+
xp_assert_equal(counts, xp.asarray([2, 3]))
151+
152+
153+
def check_lazy_apply_none_shape_broadcast(x: Array) -> Array:
154+
def f(x: Array) -> Array:
155+
return x
156+
157+
x = x[x > 1]
158+
return lazy_apply(f, x)
159+
160+
161+
lazy_xp_function(check_lazy_apply_none_shape_broadcast)
162+
163+
164+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="bool mask")
165+
@pytest.mark.xfail_xp_backend(Backend.JAX, reason="unknown shape")
166+
def test_lazy_apply_none_shape_broadcast(xp: ModuleType):
167+
"""Broadcast from input array with unknown shape"""
168+
x = xp.asarray([1, 2, 2])
169+
actual = check_lazy_apply_none_shape_broadcast(x)
170+
xp_assert_equal(actual, xp.asarray([2, 2]))
171+
172+
173+
@pytest.mark.parametrize(
174+
"as_numpy",
175+
[
176+
False,
177+
pytest.param(
178+
True,
179+
marks=[
180+
pytest.mark.skip_xp_backend(
181+
Backend.ARRAY_API_STRICT, reason="device->host copy"
182+
),
183+
pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"),
184+
pytest.mark.skip_xp_backend(Backend.SPARSE, reason="densification"),
185+
],
186+
),
187+
],
188+
)
189+
def test_lazy_apply_device(xp: ModuleType, as_numpy: bool, device: Device):
190+
def f(x: Array) -> Array:
191+
xp2 = array_namespace(x)
192+
# Deliberately forgetting to add device here to test that the
193+
# output is transferred to the right device. This is necessary when
194+
# as_numpy=True anyway.
195+
return xp2.zeros(x.shape, dtype=x.dtype)
196+
197+
x = xp.asarray([1, 2], device=device)
198+
y = lazy_apply(f, x, as_numpy=as_numpy)
199+
assert _compat.device(y) == device
200+
201+
202+
def test_lazy_apply_no_args(xp: ModuleType):
203+
with pytest.raises(ValueError, match="at least one argument"):
204+
lazy_apply(lambda: xp.zeros(1), shape=(1,), dtype=xp.zeros(1).dtype, xp=xp)
116205

117206

118207
class NT(NamedTuple):
@@ -128,7 +217,8 @@ def eager(
128217
scalar: int,
129218
) -> Array:
130219
assert isinstance(x, expect_cls)
131-
assert int(x) == 0 # JAX will crash if x isn't material
220+
# JAX will crash if x isn't material
221+
assert int(x) == 0 # type: ignore[call-overload]
132222
# Did we re-wrap the namedtuple correctly, or did it get
133223
# accidentally changed to a basic tuple?
134224
assert isinstance(z["foo"], NT)

0 commit comments

Comments
 (0)