|
8 | 8 | from types import ModuleType |
9 | 9 | from typing import TYPE_CHECKING, Any, cast, overload |
10 | 10 |
|
11 | | -from ._utils._compat import ( |
12 | | - array_namespace, |
13 | | - is_dask_namespace, |
14 | | - is_jax_namespace |
15 | | -) |
| 11 | +from ._utils._compat import array_namespace, is_dask_namespace, is_jax_namespace |
16 | 12 | from ._utils._typing import Array, DType |
17 | 13 |
|
18 | 14 | if TYPE_CHECKING: |
| 15 | + # TODO move outside TYPE_CHECKING |
| 16 | + # depends on scikit-learn abandoning Python 3.9 |
19 | 17 | # https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 |
20 | 18 | from typing import ParamSpec, TypeAlias |
21 | 19 |
|
@@ -72,8 +70,8 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04 |
72 | 70 | positional arguments and return either a single NumPy array or generic, or a |
73 | 71 | tuple or list thereof. |
74 | 72 |
|
75 | | - It must be a pure function, i.e. without side effects such as disk output, |
76 | | - as depending on the backend it may be executed more than once. |
| 73 | + It must be a pure function, i.e. without side effects, as depending on the |
| 74 | + backend it may be executed more than once. |
77 | 75 | *args : Array |
78 | 76 | One or more Array API compliant arrays. You need to be able to apply |
79 | 77 | :func:`numpy.asarray` to them to convert them to numpy; read notes below about |
@@ -225,18 +223,6 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04 |
225 | 223 | wrapped = _npfunc_wrapper(func, multi_output, xp) |
226 | 224 | out = wrapped(*args, **kwargs) |
227 | 225 |
|
228 | | - # Output validation |
229 | | - if len(out) != len(shapes): |
230 | | - msg = f"func was declared to return {len(shapes)} outputs, got {len(out)}" |
231 | | - raise ValueError(msg) |
232 | | - for out_i, shape_i, dtype_i in zip(out, shapes, dtypes, strict=True): |
233 | | - if out_i.shape != shape_i: |
234 | | - msg = f"expected shape {shape_i}, got {out_i.shape}" |
235 | | - raise ValueError(msg) |
236 | | - if not xp.isdtype(out_i.dtype, dtype_i): |
237 | | - msg = f"expected dtype {dtype_i}, got {out_i.dtype}" |
238 | | - raise ValueError(msg) |
239 | | - |
240 | 226 | return out if multi_output else out[0] |
241 | 227 |
|
242 | 228 |
|
|
0 commit comments