33# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
44from __future__ import annotations
55
6- from collections .abc import Callable , Hashable , Mapping , Sequence
6+ from collections .abc import Callable , Sequence
77from functools import wraps
88from types import ModuleType
99from typing import TYPE_CHECKING , Any , cast
2020 from typing import TypeAlias
2121
2222 import numpy as np
23- import numpy .typing as npt
2423
25- NumPyObject : TypeAlias = npt . NDArray [ DType ] | np .generic # type: ignore[no-any-explicit]
24+ NumPyObject : TypeAlias = np . ndarray [ Any , Any ] | np .generic # type: ignore[no-any-explicit]
2625
2726
2827def apply_numpy_func ( # type: ignore[no-any-explicit]
@@ -31,11 +30,6 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
3130 shapes : Sequence [tuple [int , ...]] | None = None ,
3231 dtypes : Sequence [DType ] | None = None ,
3332 xp : ModuleType | None = None ,
34- input_indices : Sequence [Sequence [Hashable ]] | None = None ,
35- core_indices : Sequence [Hashable ] | None = None ,
36- output_indices : Sequence [Sequence [Hashable ]] | None = None ,
37- adjust_chunks : Sequence [dict [Hashable , Callable [[int ], int ]]] | None = None ,
38- new_axes : Sequence [dict [Hashable , int ]] | None = None ,
3933 ** kwargs : Any ,
4034) -> tuple [Array , ...]:
4135 """
@@ -66,33 +60,6 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
6660 Default: infer the result type(s) from the input arrays.
6761 xp : array_namespace, optional
6862 The standard-compatible namespace for `args`. Default: infer.
69- input_indices : Sequence[Sequence[Hashable]], optional
70- Dask specific.
71- Axes labels for each input array, e.g. if there are two args with respectively
72- ndim=3 and 1, `input_indices` could be ``['ijk', 'j']`` or ``[(0, 1, 2),
73- (1,)]``.
74- Default: disallow Dask.
75- core_indices : Sequence[Hashable], optional
76- **Dask specific.**
77- Axes of the input arrays that cannot be broken into chunks.
78- Default: disallow Dask.
79- output_indices : Sequence[Sequence[Hashable]], optional
80- **Dask specific.**
81- Axes labels for each output array. If `func` returns a single (non-sequence)
82- output, this must be a sequence containing a single sequence of labels, e.g.
83- ``['ijk']``.
84- Default: disallow Dask.
85- adjust_chunks : Sequence[Mapping[Hashable, Callable[[int], int]]], optional
86- **Dask specific.**
87- Sequence of dicts, one per output, mapping index to function to be applied to
88- each chunk to determine the output size. The total must add up to the output
89- shape.
90- Default: on Dask, the size along each index cannot change.
91- new_axes : Sequence[Mapping[Hashable, int]], optional
92- **Dask specific.**
93- New indexes and their dimension lengths, one per output.
94- Default: on Dask, there can't be `output_indices` that don't appear in
95- `input_indices`.
9663 **kwargs : Any, optional
9764 Additional keyword arguments to pass verbatim to `func`.
9865 Any array objects in them won't be converted to NumPy.
@@ -124,43 +91,22 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
12491 <https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.
12592
12693 Dask
127- This allows applying eager functions to the individual chunks of dask arrays.
128- The dask graph won't be computed. As a special limitation, `func` must return
129- exactly one output.
94+ This allows applying eager functions to dask arrays.
95+ The dask graph won't be computed.
13096
131- In order to enable running on Dask you need to specify at least
132- `input_indices`, `output_indices`, and `core_indices`, but you may also need
133- `adjust_chunks` and `new_axes` depending on the function .
97+ `apply_numpy_func` doesn't know if `func` reduces along any axes and shape
98+ changes are non-trivial in chunked Dask arrays. For these reasons, all inputs
99+ will be rechunked into a single chunk.
134100
135- Read `dask.array.blockwise`:
136- - ``input_indices`` map to the even ``*args`` of `dask.array.blockwise`
137- - ``output_indices[0]`` maps to the ``out_ind`` parameter
138- - ``adjust_chunks[0]`` maps to the ``adjust_chunks`` parameter
139- - ``new_axes[0]`` maps to the ``new_axes`` parameter
101+ .. warning::
140102
141- ``core_indices`` is a safety measure to prevent incorrect results on
142- Dask along chunked axes. Consider this::
103+ The whole operation needs to fit in memory all at once on a single worker.
143104
144- >>> apply_numpy_func(lambda x: x + x.sum(axis=0), x,
145- ... input_indices=['ij'], output_indices=['ij'])
146-
147- The above example would produce incorrect results if x is a dask array with more
148- than one chunk along axis 0, as each chunk will calculate its own local
149- subtotal. To prevent this, we need to declare the first axis of ``args[0]`` as a
150- *core axis*::
151-
152- >>> apply_numpy_func(lambda x: x + x.sum(axis=0), x,
153- ... input_indices=['ij'], output_indices=['ij'],
154- ... core_indices='i')
155-
156- This will cause `apply_numpy_func` to raise if the first axis of `x` is broken
157- along multiple chunks, thus forcing the final user to rechunk ahead of time:
158-
159- >>> x = x.chunk({0: -1})
160-
161- This needs to always be a conscious decision on behalf of the final user, as the
162- new chunks will be larger than the old and may cause memory issues, unless chunk
163- size is reduced along a different, non-core axis.
105+ The outputs will also be returned as a single chunk and you should consider
106+ rechunking them into smaller chunks afterwards.
107+ If you want to distribute the calculation across multiple workers, you
108+ should use `dask.array.map_blocks`, `dask.array.blockwise`,
109+ `dask.array.map_overlap`, or a native Dask wrapper instead of this function.
164110 """
165111 if xp is None :
166112 xp = array_namespace (* args )
@@ -177,68 +123,30 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
177123 raise ValueError (msg )
178124
179125 if is_dask_namespace (xp ):
180- # General validation
181- if len (shapes ) > 1 :
182- msg = "dask.array.map_blocks() does not support multiple outputs"
183- raise NotImplementedError (msg )
184- if input_indices is None or output_indices is None or core_indices is None :
185- msg = (
186- "Dask is disallowed unless one declares input_indices, "
187- "output_indices, and core_indices"
188- )
189- raise ValueError (msg )
190- if len (input_indices ) != len (args ):
191- msg = f"got { len (input_indices )} input_indices and { len (args )} args"
192- raise ValueError (msg )
193- if len (output_indices ) != len (shapes ):
194- msg = f"got { len (output_indices )} input_indices and { len (shapes )} shapes"
195- raise NotImplementedError (msg )
196- if isinstance (adjust_chunks , Mapping ):
197- msg = "adjust_chunks must be a sequence of mappings"
198- raise ValueError (msg )
199- if adjust_chunks is not None and len (adjust_chunks ) != len (shapes ):
200- msg = f"got { len (adjust_chunks )} adjust_chunks and { len (shapes )} shapes"
201- raise ValueError (msg )
202- if isinstance (new_axes , Mapping ):
203- msg = "new_axes must be a sequence of mappings"
204- raise ValueError (msg )
205- if new_axes is not None and len (new_axes ) != len (shapes ):
206- msg = f"got { len (new_axes )} new_axes and { len (shapes )} shapes"
207- raise ValueError (msg )
126+ import dask # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
127+
128+ metas = [arg ._meta for arg in args if hasattr (arg , "_meta" )] # pylint: disable=protected-access
129+ meta_xp = array_namespace (* metas )
130+ meta = metas [0 ]
208131
209- # core_indices validation
210- for inp_idx , arg in zip (input_indices , args , strict = True ):
211- for i , chunks in zip (inp_idx , arg .chunks , strict = True ):
212- if i in core_indices and len (chunks ) > 1 :
213- msg = f"Core index { i } is broken into multiple chunks"
214- raise ValueError (msg )
215-
216- meta_xp = array_namespace (* (getattr (arg , "meta" , None ) for arg in args ))
217- wrapped = _npfunc_single_output_wrapper (func , meta_xp )
218- dask_args = []
219- for arg , inp_idx in zip (args , input_indices , strict = True ):
220- dask_args += [arg , inp_idx ]
221-
222- out = xp .blockwise (
223- wrapped ,
224- output_indices [0 ],
225- * dask_args ,
226- dtype = dtypes [0 ],
227- adjust_chunks = adjust_chunks [0 ] if adjust_chunks is not None else None ,
228- new_axes = new_axes [0 ] if new_axes is not None else None ,
229- ** kwargs ,
132+ wrapped = dask .delayed (_npfunc_wrapper (func , meta_xp ), pure = True )
133+ # This finalizes each arg, which is the same as arg.rechunk(-1)
134+ # Please read docstring above for why we're not using
135+ # dask.array.map_blocks or dask.array.blockwise!
136+ delayed_out = wrapped (* args , ** kwargs )
137+
138+ return tuple (
139+ xp .from_delayed (delayed_out [i ], shape = shape , dtype = dtype , meta = meta )
140+ for i , (shape , dtype ) in enumerate (zip (shapes , dtypes , strict = True ))
230141 )
231- if out .shape != shapes [0 ]:
232- msg = f"expected shape { shapes [0 ]} , but got { out .shape } from indices"
233- raise ValueError (msg )
234- return (out ,)
235142
236- wrapped = _npfunc_tuple_output_wrapper (func , xp )
143+ wrapped = _npfunc_wrapper (func , xp )
237144 if is_jax_namespace (xp ):
238145 # If we're inside jax.jit, we can't eagerly convert
239146 # the JAX tracer objects to numpy.
240147 # Instead, we delay calling wrapped, which will receive
241148 # as arguments and will return JAX eager arrays.
149+
242150 import jax # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
243151
244152 return cast (
@@ -271,17 +179,17 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
271179 return out # type: ignore[no-any-return]
272180
273181
274- def _npfunc_tuple_output_wrapper ( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
182+ def _npfunc_wrapper ( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
275183 func : Callable [..., NumPyObject | Sequence [NumPyObject ]],
276184 xp : ModuleType ,
277185) -> Callable [..., tuple [Array , ...]]:
278186 """
279187 Helper of `apply_numpy_func`.
280188
281189 Given a function that accepts one or more numpy arrays as positional arguments and
282- returns a single numpy array or a sequence of numpy arrays,
283- return a function that accepts the same number of Array API arrays and always
284- returns a tuple of Array API array.
190+ returns a single numpy array or a sequence of numpy arrays, return a function that
191+ accepts the same number of Array API arrays and always returns a tuple of Array API
192+ arrays.
285193
286194 Any keyword arguments are passed through verbatim to the wrapped function.
287195
@@ -290,6 +198,7 @@ def _npfunc_tuple_output_wrapper( # type: ignore[no-any-explicit] # numpydoc i
290198 densification for sparse arrays, device->host transfer for cupy and torch arrays).
291199 """
292200
201+ # On Dask, @wraps causes the graph key to contain the wrapped function's name
293202 @wraps (func )
294203 def wrapper ( # type: ignore[no-any-decorated,no-any-explicit]
295204 * args : Array , ** kwargs : Any
@@ -311,41 +220,3 @@ def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
311220 return tuple (xp .asarray (o ) for o in out )
312221
313222 return wrapper
314-
315-
316- def _npfunc_single_output_wrapper ( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
317- func : Callable [..., NumPyObject | Sequence [NumPyObject ]],
318- xp : ModuleType ,
319- ) -> Callable [..., Array ]:
320- """
321- Dask-specific helper of `apply_numpy_func`.
322-
323- Variant of `_npfunc_tuple_output_wrapper`, to be used with Dask which, at the time
324- of writing, does not support multiple outputs in `dask.array.blockwise`.
325-
326- func may return a single numpy object or a sequence with exactly one numpy object.
327- The wrapper returns a single Array object, with no tuple wrapping.
328- """
329-
330- # @wraps causes the generated dask key to contain the name of the wrapped function
331- @wraps (func )
332- def wrapper ( # type: ignore[no-any-decorated,no-any-explicit] # numpydoc ignore=GL08
333- * args : Array , ** kwargs : Any
334- ) -> Array :
335- import numpy as np # pylint: disable=import-outside-toplevel
336-
337- args = tuple (np .asarray (arg ) for arg in args )
338- out = func (* args , ** kwargs )
339-
340- if not isinstance (out , np .ndarray | np .generic ):
341- if not isinstance (out , Sequence ) or len (out ) != 1 : # pyright: ignore[reportUnnecessaryIsInstance]
342- msg = (
343- "apply_numpy_func: func must return a single numpy object or a "
344- f"sequence with exactly one numpy object; got { out } "
345- )
346- raise ValueError (msg )
347- out = out [0 ]
348-
349- return xp .asarray (out )
350-
351- return wrapper
0 commit comments