1919 from typing import ParamSpec , TypeAlias
2020
2121 import numpy as np
22+ from numpy .typing import ArrayLike
2223
2324 NumPyObject : TypeAlias = np .ndarray [Any , Any ] | np .generic # type: ignore[no-any-explicit]
2425 P = ParamSpec ("P" )
@@ -32,58 +33,74 @@ class P: # pylint: disable=missing-class-docstring
3233
3334
3435@overload
35- def apply_numpy_func ( # type: ignore[valid-type]
36- func : Callable [P , NumPyObject ],
36+ def apply_lazy ( # type: ignore[valid-type]
37+ func : Callable [P , ArrayLike ],
3738 * args : Array ,
3839 shape : tuple [int | None , ...] | None = None ,
3940 dtype : DType | None = None ,
41+ as_numpy : bool = False ,
4042 xp : ModuleType | None = None ,
4143 ** kwargs : P .kwargs , # pyright: ignore[reportGeneralTypeIssues]
4244) -> Array : ... # numpydoc ignore=GL08
4345
4446
4547@overload
46- def apply_numpy_func ( # type: ignore[valid-type]
47- func : Callable [P , Sequence [NumPyObject ]],
48+ def apply_lazy ( # type: ignore[valid-type]
49+ func : Callable [P , Sequence [ArrayLike ]],
4850 * args : Array ,
4951 shape : Sequence [tuple [int | None , ...]],
5052 dtype : Sequence [DType ] | None = None ,
53+ as_numpy : bool = False ,
5154 xp : ModuleType | None = None ,
5255 ** kwargs : P .kwargs , # pyright: ignore[reportGeneralTypeIssues]
5356) -> tuple [Array , ...]: ... # numpydoc ignore=GL08
5457
5558
56- def apply_numpy_func ( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
57- func : Callable [P , NumPyObject | Sequence [NumPyObject ]],
59+ def apply_lazy ( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
60+ func : Callable [P , Array | Sequence [ArrayLike ]],
5861 * args : Array ,
5962 shape : tuple [int | None , ...] | Sequence [tuple [int | None , ...]] | None = None ,
6063 dtype : DType | Sequence [DType ] | None = None ,
64+ as_numpy : bool = False ,
6165 xp : ModuleType | None = None ,
6266 ** kwargs : P .kwargs , # pyright: ignore[reportGeneralTypeIssues]
6367) -> Array | tuple [Array , ...]:
6468 """
65- Apply a function that operates on NumPy arrays to Array API compliant arrays.
69+ Lazily apply an eager function.
70+
71+ If the backend of the input arrays is lazy, e.g. Dask or jitted JAX, the execution
72+ of the function is delayed until the graph is materialized; if it's eager, the
73+ function is executed immediately.
6674
6775 Parameters
6876 ----------
6977 func : callable
70- The function to apply. It must accept one or more NumPy arrays or generics as
71- positional arguments and return either a single NumPy array or generic, or a
72- tuple or list thereof.
78+ The function to apply.
79+
80+ It must accept one or more array API compliant arrays as positional arguments.
81+ If `as_numpy=True`, inputs are converted to NumPy before they are passed to
82+ `func`.
83+ It must return either a single array-like or a sequence of array-likes.
7384
74- It must be a pure function, i.e. without side effects, as depending on the
85+ `func` must be a pure function, i.e. without side effects, as depending on the
7586 backend it may be executed more than once.
7687 *args : Array
77- One or more Array API compliant arrays. You need to be able to apply
78- :func:`numpy.asarray` to them to convert them to numpy; read notes below about
79- specific backends.
88+ One or more Array API compliant arrays.
89+
90+ If `as_numpy=True`, you need to be able to apply :func:`numpy.asarray` to them
91+ to convert them to numpy; read notes below about specific backends.
8092 shape : tuple[int | None, ...] | Sequence[tuple[int, ...]], optional
8193 Output shape or sequence of output shapes, one for each output of `func`.
8294 Default: assume single output and broadcast shapes of the input arrays.
8395 dtype : DType | Sequence[DType], optional
8496 Output dtype or sequence of output dtypes, one for each output of `func`.
8597 dtype(s) must belong to the same array namespace as the input arrays.
8698 Default: infer the result type(s) from the input arrays.
99+ as_numpy : bool, optional
100+ If True, convert the input arrays to NumPy before passing them to `func`.
101+ This is particularly useful to make numpy-only functions, e.g. written in Cython
102+ or Numba, work transparently API arrays.
103+ Default: False.
87104 xp : array_namespace, optional
88105 The standard-compatible namespace for `args`. Default: infer.
89106 **kwargs : Any, optional
@@ -95,7 +112,7 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
95112 Array | tuple[Array, ...]
96113 The result(s) of `func` applied to the input arrays, wrapped in the same
97114 array namespace as the inputs.
98- If shape is omitted or a `tuple[int, ...]`, this is a single array.
115+ If shape is omitted or a `tuple[int | None , ...]`, this is a single array.
99116 Otherwise, it's a tuple of arrays.
100117
101118 Notes
@@ -106,23 +123,26 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
106123 When running inside `jax.jit`, `shape` must be fully known, i.e. it cannot
107124 contain any `None` elements.
108125
109- The :doc:`jax:transfer_guard` may prevent arrays on a GPU device from being
110- transferred back to CPU. This is treated as an implicit transfer.
126+ Using this with `as_numpy=False` is particularly useful to apply non-jittable
127+ JAX functions to arrays on GPU devices.
128+ If `as_numpy=True`, the :doc:`jax:transfer_guard` may prevent arrays on a GPU
129+ device from being transferred back to CPU. This is treated as an implicit
130+ transfer.
111131
112132 PyTorch, CuPy
113- These backends raise by default if you attempt to convert arrays on a GPU device
114- to NumPy.
133+ If `as_numpy=True`, these backends raise by default if you attempt to convert
134+ arrays on a GPU device to NumPy.
115135
116136 Sparse
117- By default, sparse prevents implicit densification through
137+ If `as_numpy=True`, by default sparse prevents implicit densification through
118138 :func:`numpy.asarray`. `This safety mechanism can be disabled
119139 <https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.
120140
121141 Dask
122142 This allows applying eager functions to dask arrays.
123143 The dask graph won't be computed.
124144
125- `apply_numpy_func ` doesn't know if `func` reduces along any axes; also, shape
145+ `apply_lazy ` doesn't know if `func` reduces along any axes; also, shape
126146 changes are non-trivial in chunked Dask arrays. For these reasons, all inputs
127147 will be rechunked into a single chunk.
128148
@@ -136,7 +156,13 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
136156 If you want to distribute the calculation across multiple workers, you
137157 should use :func:`dask.array.map_blocks`, :func:`dask.array.map_overlap`,
138158 :func:`dask.array.blockwise`, or a native Dask wrapper instead of
139- `apply_numpy_func`.
159+ `apply_lazy`.
160+
161+ Dask wrapping around other backends
162+ If `as_numpy=False`, `func` will receive in input eager arrays of the meta
163+ namespace, as defined by the `._meta` attribute of the input Dask arrays.
164+ The outputs of `func` will be wrapped by the meta namespace, and then wrapped
165+ again by Dask.
140166
141167 Raises
142168 ------
@@ -202,7 +228,10 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
202228 metas = [arg ._meta for arg in args if hasattr (arg , "_meta" )] # pylint: disable=protected-access
203229 meta_xp = array_namespace (* metas )
204230
205- wrapped = dask .delayed (_npfunc_wrapper (func , multi_output , meta_xp ), pure = True )
231+ wrapped = dask .delayed (
232+ _apply_lazy_wrapper (func , as_numpy , multi_output , meta_xp ),
233+ pure = True ,
234+ )
206235 # This finalizes each arg, which is the same as arg.rechunk(-1).
207236 # Please read docstring above for why we're not using
208237 # dask.array.map_blocks or dask.array.blockwise!
@@ -227,7 +256,7 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
227256
228257 import jax # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
229258
230- wrapped = _npfunc_wrapper (func , multi_output , xp )
259+ wrapped = _apply_lazy_wrapper (func , as_numpy , multi_output , xp )
231260
232261 if any (s is None for shape in shapes for s in shape ):
233262 # Unknown output shape. Won't work with jax.jit, but it
@@ -251,19 +280,20 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
251280
252281 else :
253282 # Eager backends
254- wrapped = _npfunc_wrapper (func , multi_output , xp )
283+ wrapped = _apply_lazy_wrapper (func , as_numpy , multi_output , xp )
255284 out = wrapped (* args , ** kwargs )
256285
257286 return out if multi_output else out [0 ]
258287
259288
260- def _npfunc_wrapper ( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
261- func : Callable [..., NumPyObject | Sequence [NumPyObject ]],
289+ def _apply_lazy_wrapper ( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
290+ func : Callable [..., ArrayLike | Sequence [ArrayLike ]],
291+ as_numpy : bool ,
262292 multi_output : bool ,
263293 xp : ModuleType ,
264294) -> Callable [..., tuple [Array , ...]]:
265295 """
266- Helper of `apply_numpy_func `.
296+ Helper of `apply_lazy `.
267297
268298 Given a function that accepts one or more numpy arrays as positional arguments and
269299 returns a single numpy array or a sequence of numpy arrays, return a function that
@@ -284,19 +314,13 @@ def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
284314 ) -> tuple [Array , ...]: # numpydoc ignore=GL08
285315 import numpy as np # pylint: disable=import-outside-toplevel
286316
287- args = tuple (np .asarray (arg ) for arg in args )
317+ if as_numpy :
318+ args = tuple (np .asarray (arg ) for arg in args )
288319 out = func (* args , ** kwargs )
289320
290- # Stay relaxed on output validation, e.g. in case func returns a
291- # Python scalar instead of a np.generic
292321 if multi_output :
293- if not isinstance (out , Sequence ) or isinstance (out , np .ndarray ):
294- msg = "Expected multiple outputs, got a single one"
295- raise ValueError (msg )
296- outs = out
297- else :
298- outs = [cast ("NumPyObject" , out )]
299-
300- return tuple (xp .asarray (o ) for o in outs )
322+ assert isinstance (out , Sequence )
323+ return tuple (xp .asarray (o ) for o in out )
324+ return (xp .asarray (out ),)
301325
302326 return wrapper
0 commit comments