1717
1818if TYPE_CHECKING :
1919 # https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
20- from typing import TypeAlias
20+ from typing import ParamSpec , TypeAlias
2121
2222 import numpy as np
2323
2424 NumPyObject : TypeAlias = np .ndarray [Any , Any ] | np .generic # type: ignore[no-any-explicit]
25- KwArg : TypeAlias = Any # type: ignore[no-any-explicit]
25+ P = ParamSpec ( "P" )
2626
2727
2828@overload
29- def apply_numpy_func (
30- func : Callable [... , NumPyObject ],
29+ def apply_numpy_func ( # type: ignore[valid-type]
30+ func : Callable [P , NumPyObject ],
3131 * args : Array ,
3232 shape : tuple [int , ...] | None = None ,
3333 dtype : DType | None = None ,
3434 xp : ModuleType | None = None ,
35- ** kwargs : KwArg ,
35+ ** kwargs : P . kwargs , # pyright: ignore[reportGeneralTypeIssues]
3636) -> Array : ... # numpydoc ignore=GL08
3737
3838
3939@overload
40- def apply_numpy_func ( # type: ignore[no-any-decorated ]
41- func : Callable [... , Sequence [NumPyObject ]],
40+ def apply_numpy_func ( # type: ignore[valid-type ]
41+ func : Callable [P , Sequence [NumPyObject ]],
4242 * args : Array ,
4343 shape : Sequence [tuple [int , ...]],
4444 dtype : Sequence [DType ] | None = None ,
4545 xp : ModuleType | None = None ,
46- ** kwargs : Any ,
46+ ** kwargs : P . kwargs , # pyright: ignore[reportGeneralTypeIssues]
4747) -> tuple [Array , ...]: ... # numpydoc ignore=GL08
4848
4949
50- def apply_numpy_func ( # type: ignore[no-any-explicit ]
51- func : Callable [... , NumPyObject | Sequence [NumPyObject ]],
50+ def apply_numpy_func ( # type: ignore[valid-type ]
51+ func : Callable [P , NumPyObject | Sequence [NumPyObject ]],
5252 * args : Array ,
5353 shape : tuple [int , ...] | Sequence [tuple [int , ...]] | None = None ,
5454 dtype : DType | Sequence [DType ] | None = None ,
5555 xp : ModuleType | None = None ,
56- ** kwargs : Any ,
56+ ** kwargs : P . kwargs , # pyright: ignore[reportGeneralTypeIssues]
5757) -> Array | tuple [Array , ...]:
5858 """
5959 Apply a function that operates on NumPy arrays to Array API compliant arrays.
@@ -139,7 +139,7 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
139139 elif isinstance (shape , tuple ) and all (isinstance (s , int ) for s in shape ):
140140 shapes = [shape ]
141141 else :
142- shapes = shape
142+ shapes = list ( shape )
143143 multi_output = True
144144
145145 if dtype is None :
@@ -148,7 +148,7 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
148148 if not isinstance (dtype , Sequence ):
149149 msg = "Got sequence of shapes but only one dtype"
150150 raise TypeError (msg )
151- dtypes = dtype
151+ dtypes = list ( dtype ) # pyright: ignore[reportUnknownArgumentType]
152152 else :
153153 if isinstance (dtype , Sequence ):
154154 msg = "Got single shape but multiple dtypes"
@@ -254,13 +254,16 @@ def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
254254 args = tuple (np .asarray (arg ) for arg in args )
255255 out = func (* args , ** kwargs )
256256
257+ # Stay relaxed on output validation, e.g. in case func returns a
258+ # Python scalar instead of a np.generic
257259 if multi_output :
258260 if not isinstance (out , Sequence ) or isinstance (out , np .ndarray ):
259261 msg = "Expected multiple outputs, got a single one"
260262 raise ValueError (msg )
263+ outs = out
261264 else :
262- out = ( out ,)
265+ outs = [ cast ( "NumPyObject" , out )]
263266
264- return tuple (xp .asarray (o ) for o in out )
267+ return tuple (xp .asarray (o ) for o in outs )
265268
266269 return wrapper
0 commit comments