Skip to content

Commit 1dbe2d0

Browse files
committed
Self-review
1 parent d70e03c commit 1dbe2d0

File tree

2 files changed

+60
-17
lines changed

2 files changed

+60
-17
lines changed

src/array_api_extra/_lib/_lazy.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def lazy_apply( # type: ignore[valid-type]
6464

6565

6666
def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
67-
func: Callable[P, Array | Sequence[ArrayLike]],
67+
func: Callable[P, ArrayLike | Sequence[ArrayLike]],
6868
*args: Array,
6969
shape: tuple[int | None, ...] | Sequence[tuple[int | None, ...]] | None = None,
7070
dtype: DType | Sequence[DType] | None = None,
@@ -90,13 +90,13 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
9090
It must return either a single array-like or a sequence of array-likes.
9191
9292
`func` must be a pure function, i.e. without side effects, as depending on the
93-
backend it may be executed more than once.
93+
backend it may be executed more than once or never.
9494
*args : Array
9595
One or more Array API compliant arrays.
9696
9797
If `as_numpy=True`, you need to be able to apply :func:`numpy.asarray` to them
9898
to convert them to numpy; read notes below about specific backends.
99-
shape : tuple[int | None, ...] | Sequence[tuple[int, ...]], optional
99+
shape : tuple[int | None, ...] | Sequence[tuple[int | None, ...]], optional
100100
Output shape or sequence of output shapes, one for each output of `func`.
101101
Default: assume single output and broadcast shapes of the input arrays.
102102
dtype : DType | Sequence[DType], optional
@@ -119,34 +119,34 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
119119
Array | tuple[Array, ...]
120120
The result(s) of `func` applied to the input arrays, wrapped in the same
121121
array namespace as the inputs.
122-
If shape is omitted or a `tuple[int | None, ...]`, this is a single array.
123-
Otherwise, it's a tuple of arrays.
122+
If shape is omitted or a single `tuple[int | None, ...]`, return a single array.
123+
Otherwise, return a tuple of arrays.
124124
125125
Notes
126126
-----
127127
JAX
128128
This allows applying eager functions to jitted JAX arrays, which are lazy.
129129
The function won't be applied until the JAX array is materialized.
130-
When running inside `jax.jit`, `shape` must be fully known, i.e. it cannot
130+
When running inside ``jax.jit``, `shape` must be fully known, i.e. it cannot
131131
contain any `None` elements.
132132
133133
.. warning::
134134
135-
`func` must never raise if it's run inside `jax.jit`, as its behavior is
135+
`func` must never raise inside ``jax.jit``, as the resulting behavior is
136136
undefined.
137137
138138
Using this with `as_numpy=False` is particularly useful to apply non-jittable
139139
JAX functions to arrays on GPU devices.
140-
If `as_numpy=True`, the :doc:`jax:transfer_guard` may prevent arrays on a GPU
140+
If ``as_numpy=True``, the :doc:`jax:transfer_guard` may prevent arrays on a GPU
141141
device from being transferred back to CPU. This is treated as an implicit
142142
transfer.
143143
144144
PyTorch, CuPy
145-
If `as_numpy=True`, these backends raise by default if you attempt to convert
145+
If ``as_numpy=True``, these backends raise by default if you attempt to convert
146146
arrays on a GPU device to NumPy.
147147
148148
Sparse
149-
If `as_numpy=True`, by default sparse prevents implicit densification through
149+
If ``as_numpy=True``, by default sparse prevents implicit densification through
150150
:func:`numpy.asarray`. `This safety mechanism can be disabled
151151
<https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.
152152
@@ -171,21 +171,21 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
171171
`lazy_apply`.
172172
173173
Dask wrapping around other backends
174-
If `as_numpy=False`, `func` will receive in input eager arrays of the meta
175-
namespace, as defined by the `._meta` attribute of the input Dask arrays.
174+
If ``as_numpy=False``, `func` will receive in input eager arrays of the meta
175+
namespace, as defined by the ``._meta`` attribute of the input Dask arrays.
176176
The outputs of `func` will be wrapped by the meta namespace, and then wrapped
177177
again by Dask.
178178
179179
Raises
180180
------
181181
jax.errors.TracerArrayConversionError
182-
When `xp=jax.numpy`, `shape` is unknown (it contains None on one or more axes)
183-
and this function was called inside `jax.jit`.
182+
When ``xp=jax.numpy``, `shape` is unknown (it contains None on one or more axes)
183+
and this function was called inside ``jax.jit``.
184184
RuntimeError
185-
When `xp=sparse` and auto-densification is disabled.
185+
When ``xp=sparse`` and auto-densification is disabled.
186186
Exception (backend-specific)
187187
When the backend disallows implicit device to host transfers and the input
188-
arrays are on a device, e.g. on GPU.
188+
arrays are on a non-CPU device, e.g. on GPU.
189189
190190
See Also
191191
--------
@@ -237,6 +237,7 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
237237
raise ValueError(msg)
238238
del shape
239239
del dtype
240+
# End of shape and dtype parsing
240241

241242
# Backend-specific branches
242243
if is_dask_namespace(xp):

tests/test_lazy.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,27 @@ def f(x: Array) -> Array:
170170
xp_assert_equal(y, x + 1)
171171

172172

173+
def test_lazy_apply_dask_non_numpy_meta(da: ModuleType):
174+
"""Test dask wrapping around a meta-namespace other than numpy."""
175+
# At the moment of writing, of all Array API namespaces cupy is
176+
# the only one that Dask supports.
177+
# For this reason, we can only test as_numpy=False since
178+
# np.asarray(cp.Array) is blocked by the transfer guard.
179+
180+
cp = pytest.importorskip("cupy")
181+
cp = array_namespace(cp.empty(0))
182+
x_cp = cp.asarray([1, 2, 3])
183+
x_da = da.asarray([1, 2, 3]).map_blocks(cp.asarray)
184+
assert array_namespace(x_da._meta) is cp
185+
186+
def f(x: Array) -> Array:
187+
return x + 1
188+
189+
y = lazy_apply(f, x_da)
190+
assert array_namespace(y._meta) is cp
191+
xp_assert_equal(y.compute(), x_cp + 1)
192+
193+
173194
@pytest.mark.xfail_xp_backend(Backend.JAX, reason="unknown shape")
174195
def test_lazy_apply_none_shape_in_args(xp: ModuleType, library: Backend):
175196
x = xp.asarray([1, 1, 2, 2, 2])
@@ -241,6 +262,27 @@ def f(x: Array) -> Array:
241262
assert _compat.device(y) == device
242263

243264

265+
def test_lazy_apply_arraylike(xp: ModuleType):
266+
"""Wrapped func returns an array-like"""
267+
x = xp.asarray([1, 2, 3])
268+
269+
# Single output
270+
def f(x: Array) -> int:
271+
return x.shape[0] # type: ignore[no-any-return]
272+
273+
expect = xp.asarray(3)
274+
actual = lazy_apply(f, x, shape=(), dtype=expect.dtype)
275+
xp_assert_equal(actual, expect)
276+
277+
# Multi output
278+
def g(x: Array) -> tuple[int, ...]:
279+
return x.shape[0], x.shape
280+
281+
actual = lazy_apply(g, x, shape=((), (1,)), dtype=(expect.dtype, expect.dtype))
282+
xp_assert_equal(actual[0], xp.asarray(3))
283+
xp_assert_equal(actual[1], xp.asarray([3]))
284+
285+
244286
class NT(NamedTuple):
245287
a: Array
246288

@@ -291,7 +333,7 @@ def eager(
291333

292334

293335
@as_numpy
294-
def test_lazy_apply_kwargs(xp: ModuleType, library: Backend, as_numpy: bool) -> None:
336+
def test_lazy_apply_kwargs(xp: ModuleType, library: Backend, as_numpy: bool):
295337
"""When as_numpy=True, search and replace arrays in the (nested) keywords arguments
296338
with numpy arrays, and leave the rest untouched."""
297339
expect_cls = (

0 commit comments

Comments
 (0)