@@ -64,7 +64,7 @@ def lazy_apply( # type: ignore[valid-type]
6464
6565
6666def lazy_apply ( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
67- func : Callable [P , Array | Sequence [ArrayLike ]],
67+ func : Callable [P , ArrayLike | Sequence [ArrayLike ]],
6868 * args : Array ,
6969 shape : tuple [int | None , ...] | Sequence [tuple [int | None , ...]] | None = None ,
7070 dtype : DType | Sequence [DType ] | None = None ,
@@ -90,13 +90,13 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
9090 It must return either a single array-like or a sequence of array-likes.
9191
9292 `func` must be a pure function, i.e. without side effects, as depending on the
93- backend it may be executed more than once.
93+ backend it may be executed more than once or never .
9494 *args : Array
9595 One or more Array API compliant arrays.
9696
9797 If `as_numpy=True`, you need to be able to apply :func:`numpy.asarray` to them
9898 to convert them to numpy; read notes below about specific backends.
99- shape : tuple[int | None, ...] | Sequence[tuple[int, ...]], optional
99+ shape : tuple[int | None, ...] | Sequence[tuple[int | None , ...]], optional
100100 Output shape or sequence of output shapes, one for each output of `func`.
101101 Default: assume single output and broadcast shapes of the input arrays.
102102 dtype : DType | Sequence[DType], optional
@@ -119,34 +119,34 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
119119 Array | tuple[Array, ...]
120120 The result(s) of `func` applied to the input arrays, wrapped in the same
121121 array namespace as the inputs.
122- If shape is omitted or a `tuple[int | None, ...]`, this is a single array.
123- Otherwise, it's a tuple of arrays.
122+ If shape is omitted or a single `tuple[int | None, ...]`, return a single array.
123+ Otherwise, return a tuple of arrays.
124124
125125 Notes
126126 -----
127127 JAX
128128 This allows applying eager functions to jitted JAX arrays, which are lazy.
129129 The function won't be applied until the JAX array is materialized.
130- When running inside `jax.jit`, `shape` must be fully known, i.e. it cannot
130+ When running inside `` jax.jit` `, `shape` must be fully known, i.e. it cannot
131131 contain any `None` elements.
132132
133133 .. warning::
134134
135- `func` must never raise if it's run inside `jax.jit`, as its behavior is
135+ `func` must never raise inside `` jax.jit`` , as the resulting behavior is
136136 undefined.
137137
138138 Using this with `as_numpy=False` is particularly useful to apply non-jittable
139139 JAX functions to arrays on GPU devices.
140- If `as_numpy=True`, the :doc:`jax:transfer_guard` may prevent arrays on a GPU
140+ If `` as_numpy=True` `, the :doc:`jax:transfer_guard` may prevent arrays on a GPU
141141 device from being transferred back to CPU. This is treated as an implicit
142142 transfer.
143143
144144 PyTorch, CuPy
145- If `as_numpy=True`, these backends raise by default if you attempt to convert
145+ If `` as_numpy=True` `, these backends raise by default if you attempt to convert
146146 arrays on a GPU device to NumPy.
147147
148148 Sparse
149- If `as_numpy=True`, by default sparse prevents implicit densification through
149+ If `` as_numpy=True` `, by default sparse prevents implicit densification through
150150 :func:`numpy.asarray`. `This safety mechanism can be disabled
151151 <https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.
152152
@@ -171,21 +171,21 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
171171 `lazy_apply`.
172172
173173 Dask wrapping around other backends
174- If `as_numpy=False`, `func` will receive in input eager arrays of the meta
175- namespace, as defined by the `._meta` attribute of the input Dask arrays.
174+ If `` as_numpy=False` `, `func` will receive in input eager arrays of the meta
175+ namespace, as defined by the `` ._meta` ` attribute of the input Dask arrays.
176176 The outputs of `func` will be wrapped by the meta namespace, and then wrapped
177177 again by Dask.
178178
179179 Raises
180180 ------
181181 jax.errors.TracerArrayConversionError
182- When `xp=jax.numpy`, `shape` is unknown (it contains None on one or more axes)
183- and this function was called inside `jax.jit`.
182+ When `` xp=jax.numpy` `, `shape` is unknown (it contains None on one or more axes)
183+ and this function was called inside `` jax.jit` `.
184184 RuntimeError
185- When `xp=sparse` and auto-densification is disabled.
185+ When `` xp=sparse` ` and auto-densification is disabled.
186186 Exception (backend-specific)
187187 When the backend disallows implicit device to host transfers and the input
188- arrays are on a device, e.g. on GPU.
188+ arrays are on a non-CPU device, e.g. on GPU.
189189
190190 See Also
191191 --------
@@ -237,6 +237,7 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
237237 raise ValueError (msg )
238238 del shape
239239 del dtype
240+ # End of shape and dtype parsing
240241
241242 # Backend-specific branches
242243 if is_dask_namespace (xp ):
0 commit comments