66from collections .abc import Callable , Sequence
77from functools import wraps
88from types import ModuleType
9- from typing import TYPE_CHECKING , Any , cast
9+ from typing import TYPE_CHECKING , Any , cast , overload
1010
1111from ._lib ._compat import (
1212 array_namespace ,
2222 import numpy as np
2323
2424 NumPyObject : TypeAlias = np .ndarray [Any , Any ] | np .generic # type: ignore[no-any-explicit]
25+ KwArg : TypeAlias = Any # type: ignore[no-any-explicit]
26+
27+
28+ @overload
29+ def apply_numpy_func (
30+ func : Callable [..., NumPyObject ],
31+ * args : Array ,
32+ shape : tuple [int , ...] | None = None ,
33+ dtype : DType | None = None ,
34+ xp : ModuleType | None = None ,
35+ ** kwargs : KwArg ,
36+ ) -> Array : ... # numpydoc ignore=GL08
37+
38+
39+ @overload
40+ def apply_numpy_func ( # type: ignore[no-any-decorated]
41+ func : Callable [..., Sequence [NumPyObject ]],
42+ * args : Array ,
43+ shape : Sequence [tuple [int , ...]],
44+ dtype : Sequence [DType ] | None = None ,
45+ xp : ModuleType | None = None ,
46+ ** kwargs : Any ,
47+ ) -> tuple [Array , ...]: ... # numpydoc ignore=GL08
2548
2649
2750def apply_numpy_func ( # type: ignore[no-any-explicit]
2851 func : Callable [..., NumPyObject | Sequence [NumPyObject ]],
2952 * args : Array ,
30- shapes : Sequence [tuple [int , ...]] | None = None ,
31- dtypes : Sequence [DType ] | None = None ,
53+ shape : tuple [ int , ...] | Sequence [tuple [int , ...]] | None = None ,
54+ dtype : DType | Sequence [DType ] | None = None ,
3255 xp : ModuleType | None = None ,
3356 ** kwargs : Any ,
34- ) -> tuple [Array , ...]:
57+ ) -> Array | tuple [Array , ...]:
3558 """
3659 Apply a function that operates on NumPy arrays to Array API compliant arrays.
3760
@@ -48,15 +71,11 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
4871 One or more Array API compliant arrays. You need to be able to apply
4972 ``np.asarray()`` to them to convert them to numpy; read notes below about
5073 specific backends.
51- shapes : Sequence[tuple[int, ...]], optional
52- Sequence of output shapes, one for each output of `func`.
53- If `func` returns a single (non-sequence) output, this must be a sequence
54- with a single element.
55- Default: assume a single output and broadcast shapes of the input arrays.
56- dtypes : Sequence[DType], optional
57- Sequence of output dtypes, one for each output of `func`.
58- If `func` returns a single (non-sequence) output, this must be a sequence
59- with a single element.
74+ shape : tuple[int, ...] | Sequence[tuple[int, ...]], optional
75+ Output shape or sequence of output shapes, one for each output of `func`.
76+ Default: assume single output and broadcast shapes of the input arrays.
77+ dtype : DType | Sequence[DType], optional
78+ Output dtype or sequence of output dtypes, one for each output of `func`.
6079 Default: infer the result type(s) from the input arrays.
6180 xp : array_namespace, optional
6281 The standard-compatible namespace for `args`. Default: infer.
@@ -66,9 +85,11 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
6685
6786 Returns
6887 -------
69- tuple[Array, ...]
70- The result(s) of `func` applied to the input arrays.
71- This is always a tuple, even if `func` returns a single output.
88+ Array | tuple[Array, ...]
89+ The result(s) of `func` applied to the input arrays, wrapped in the same
90+ array namespace as the inputs.
91+ If shape is omitted or a `tuple[int, ...]`, this is a single array.
92+ Otherwise, it's a tuple of arrays.
7293
7394 Notes
7495 -----
@@ -110,46 +131,67 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
110131 """
111132 if xp is None :
112133 xp = array_namespace (* args )
113- if shapes is None :
134+
135+ # Normalize and validate shape and dtype
136+ multi_output = False
137+ if shape is None :
114138 shapes = [xp .broadcast_shapes (* (arg .shape for arg in args ))]
115- if dtypes is None :
139+ elif isinstance (shape , tuple ) and all (isinstance (s , int ) for s in shape ):
140+ shapes = [shape ]
141+ else :
142+ shapes = shape
143+ multi_output = True
144+
145+ if dtype is None :
116146 dtypes = [xp .result_type (* args )] * len (shapes )
147+ elif multi_output :
148+ if not isinstance (dtype , Sequence ):
149+ msg = "Got sequence of shapes but only one dtype"
150+ raise TypeError (msg )
151+ dtypes = dtype
152+ else :
153+ if isinstance (dtype , Sequence ):
154+ msg = "Got single shape but multiple dtypes"
155+ raise TypeError (msg )
156+ dtypes = [dtype ]
117157
118158 if len (shapes ) != len (dtypes ):
119- msg = f"got { len (shapes )} shapes and { len (dtypes )} dtypes"
159+ msg = f"Got { len (shapes )} shapes and { len (dtypes )} dtypes"
120160 raise ValueError (msg )
121161 if len (shapes ) == 0 :
122- msg = "Must have at least one output array "
162+ msg = "func must return one or more output arrays "
123163 raise ValueError (msg )
164+ del shape
165+ del dtype
124166
167+ # Backend-specific branches
125168 if is_dask_namespace (xp ):
126169 import dask # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
127170
128171 metas = [arg ._meta for arg in args if hasattr (arg , "_meta" )] # pylint: disable=protected-access
129172 meta_xp = array_namespace (* metas )
130- meta = metas [0 ]
131173
132- wrapped = dask .delayed (_npfunc_wrapper (func , meta_xp ), pure = True )
174+ wrapped = dask .delayed (_npfunc_wrapper (func , multi_output , meta_xp ), pure = True )
133175 # This finalizes each arg, which is the same as arg.rechunk(-1)
134176 # Please read docstring above for why we're not using
135177 # dask.array.map_blocks or dask.array.blockwise!
136178 delayed_out = wrapped (* args , ** kwargs )
137179
138- return tuple (
139- xp .from_delayed (delayed_out [i ], shape = shape , dtype = dtype , meta = meta )
180+ out = tuple (
181+ xp .from_delayed (delayed_out [i ], shape = shape , dtype = dtype , meta = metas [ 0 ] )
140182 for i , (shape , dtype ) in enumerate (zip (shapes , dtypes , strict = True ))
141183 )
142184
143- wrapped = _npfunc_wrapper (func , xp )
144- if is_jax_namespace (xp ):
185+ elif is_jax_namespace (xp ):
145186 # If we're inside jax.jit, we can't eagerly convert
146187 # the JAX tracer objects to numpy.
147188 # Instead, we delay calling wrapped, which will receive
148189 # as arguments and will return JAX eager arrays.
149190
150191 import jax # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
151192
152- return cast (
193+ wrapped = _npfunc_wrapper (func , multi_output , xp )
194+ out = cast (
153195 tuple [Array , ...],
154196 jax .pure_callback (
155197 wrapped ,
@@ -162,25 +204,29 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
162204 ),
163205 )
164206
165- # Eager backends
166- out = wrapped (* args , ** kwargs )
207+ else :
208+ # Eager backends
209+ wrapped = _npfunc_wrapper (func , multi_output , xp )
210+ out = wrapped (* args , ** kwargs )
167211
168- # Output validation
169- if len (out ) != len (shapes ):
170- msg = f"func was declared to return { len (shapes )} outputs, got { len (out )} "
171- raise ValueError (msg )
172- for out_i , shape_i , dtype_i in zip (out , shapes , dtypes , strict = True ):
173- if out_i .shape != shape_i :
174- msg = f"expected shape { shape_i } , got { out_i .shape } "
175- raise ValueError (msg )
176- if not xp .isdtype (out_i .dtype , dtype_i ):
177- msg = f"expected dtype { dtype_i } , got { out_i .dtype } "
212+ # Output validation
213+ if len (out ) != len (shapes ):
214+ msg = f"func was declared to return { len (shapes )} outputs, got { len (out )} "
178215 raise ValueError (msg )
179- return out # type: ignore[no-any-return]
216+ for out_i , shape_i , dtype_i in zip (out , shapes , dtypes , strict = True ):
217+ if out_i .shape != shape_i :
218+ msg = f"expected shape { shape_i } , got { out_i .shape } "
219+ raise ValueError (msg )
220+ if not xp .isdtype (out_i .dtype , dtype_i ):
221+ msg = f"expected dtype { dtype_i } , got { out_i .dtype } "
222+ raise ValueError (msg )
223+
224+ return out if multi_output else out [0 ]
180225
181226
182227def _npfunc_wrapper ( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
183228 func : Callable [..., NumPyObject | Sequence [NumPyObject ]],
229+ multi_output : bool ,
184230 xp : ModuleType ,
185231) -> Callable [..., tuple [Array , ...]]:
186232 """
@@ -208,14 +254,12 @@ def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
208254 args = tuple (np .asarray (arg ) for arg in args )
209255 out = func (* args , ** kwargs )
210256
211- if isinstance (out , np .ndarray | np .generic ):
257+ if multi_output :
258+ if not isinstance (out , Sequence ) or isinstance (out , np .ndarray ):
259+ msg = "Expected multiple outputs, got a single one"
260+ raise ValueError (msg )
261+ else :
212262 out = (out ,)
213- elif not isinstance (out , Sequence ): # pyright: ignore[reportUnnecessaryIsInstance]
214- msg = (
215- "apply_numpy_func: func must return a numpy object or a "
216- f"sequence of numpy objects; got { out } "
217- )
218- raise TypeError (msg )
219263
220264 return tuple (xp .asarray (o ) for o in out )
221265
0 commit comments