33from __future__ import annotations
44
55import math
6+ import operator
67from collections .abc import Callable , Sequence
78from functools import partial , wraps
89from types import ModuleType
9- from typing import TYPE_CHECKING , Any , ParamSpec , TypeAlias , cast , overload
10+ from typing import TYPE_CHECKING , Any , TypeAlias , cast , overload
1011
1112from ._funcs import broadcast_shapes
1213from ._utils import _compat
2728 # Sphinx hack
2829 NumPyObject = Any
2930
30- P = ParamSpec ("P" )
31-
3231
3332@overload
34- def lazy_apply ( # type: ignore[decorated -any, valid-type ]
35- func : Callable [P , Array | ArrayLike ],
33+ def lazy_apply ( # type: ignore[explicit -any,decorated-any ]
34+ func : Callable [... , Array | ArrayLike ],
3635 * args : Array | complex | None ,
3736 shape : tuple [int | None , ...] | None = None ,
3837 dtype : DType | None = None ,
3938 as_numpy : bool = False ,
4039 xp : ModuleType | None = None ,
41- ** kwargs : P . kwargs , # pyright: ignore[reportGeneralTypeIssues]
40+ ** kwargs : Any ,
4241) -> Array : ... # numpydoc ignore=GL08
4342
4443
4544@overload
46- def lazy_apply ( # type: ignore[decorated -any, valid-type ]
47- func : Callable [P , Sequence [Array | ArrayLike ]],
45+ def lazy_apply ( # type: ignore[explicit -any,decorated-any ]
46+ func : Callable [... , Sequence [Array | ArrayLike ]],
4847 * args : Array | complex | None ,
4948 shape : Sequence [tuple [int | None , ...]],
5049 dtype : Sequence [DType ] | None = None ,
5150 as_numpy : bool = False ,
5251 xp : ModuleType | None = None ,
53- ** kwargs : P . kwargs , # pyright: ignore[reportGeneralTypeIssues]
52+ ** kwargs : Any ,
5453) -> tuple [Array , ...]: ... # numpydoc ignore=GL08
5554
5655
57- def lazy_apply ( # type: ignore[valid-type ] # numpydoc ignore=GL07,SA04
58- func : Callable [P , Array | ArrayLike | Sequence [Array | ArrayLike ]],
56+ def lazy_apply ( # type: ignore[explicit-any ] # numpydoc ignore=GL07,SA04
57+ func : Callable [... , Array | ArrayLike | Sequence [Array | ArrayLike ]],
5958 * args : Array | complex | None ,
6059 shape : tuple [int | None , ...] | Sequence [tuple [int | None , ...]] | None = None ,
6160 dtype : DType | Sequence [DType ] | None = None ,
6261 as_numpy : bool = False ,
6362 xp : ModuleType | None = None ,
64- ** kwargs : P . kwargs , # pyright: ignore[reportGeneralTypeIssues]
63+ ** kwargs : Any ,
6564) -> Array | tuple [Array , ...]:
6665 """
6766 Lazily apply an eager function.
@@ -162,10 +161,11 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
162161 The outputs will also be returned as a single chunk and you should consider
163162 rechunking them into smaller chunks afterwards.
164163
165- If you want to distribute the calculation across multiple workers, you
166- should use :func:`dask.array.map_blocks`, :func:`dask.array.map_overlap`,
167- :func:`dask.array.blockwise`, or a native Dask wrapper instead of
168- `lazy_apply`.
164+ If you want to distribute the calculation across multiple workers and your
165+ function is elementwise, you should use :func:`lazy_apply_elementwise` instead.
166+ If the function is not elementwise, you should consider writing an ad-hoc
167+ variant for Dask using primitives like :func:`dask.array.blockwise`,
168+ :func:`dask.array.map_overlap`, or a native Dask algorithm.
169169
170170 Dask wrapping around other backends
171171 If ``as_numpy=False``, `func` will receive in input eager arrays of the meta
@@ -186,9 +186,9 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
186186
187187 See Also
188188 --------
189+ lazy_apply_elementwise
189190 jax.transfer_guard
190191 jax.pure_callback
191- dask.array.map_blocks
192192 dask.array.map_overlap
193193 dask.array.blockwise
194194 """
@@ -240,7 +240,7 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
240240 if is_dask_namespace (xp ):
241241 import dask
242242
243- metas : list [Array ] = [arg ._meta for arg in array_args ] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue]
243+ metas : list [Array ] = [arg ._meta for arg in array_args ] # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue]
244244 meta_xp = array_namespace (* metas )
245245
246246 wrapped = dask .delayed ( # type: ignore[attr-defined] # pyright: ignore[reportPrivateImportUsage]
@@ -355,3 +355,145 @@ def wrapper( # type: ignore[decorated-any,explicit-any]
355355 return (xp .asarray (out , device = device ),)
356356
357357 return wrapper
358+
359+
360+ @overload
361+ def lazy_apply_elementwise ( # type: ignore[explicit-any,decorated-any]
362+ func : Callable [..., Array | ArrayLike ],
363+ * args : Array | complex | None ,
364+ dtype : DType | None = None ,
365+ as_numpy : bool = False ,
366+ xp : ModuleType | None = None ,
367+ ** kwargs : Any ,
368+ ) -> Array : ... # numpydoc ignore=GL08
369+
370+
371+ @overload
372+ def lazy_apply_elementwise ( # type: ignore[explicit-any,decorated-any]
373+ func : Callable [..., Sequence [Array | ArrayLike ]],
374+ * args : Array | complex | None ,
375+ dtype : Sequence [DType | None ],
376+ as_numpy : bool = False ,
377+ xp : ModuleType | None = None ,
378+ ** kwargs : Any ,
379+ ) -> tuple [Array , ...]: ... # numpydoc ignore=GL08
380+
381+
382+ def lazy_apply_elementwise ( # type: ignore[explicit-any]
383+ func : Callable [..., Array | ArrayLike | Sequence [Array | ArrayLike ]],
384+ * args : Array | complex | None ,
385+ dtype : DType | Sequence [DType | None ] | None = None ,
386+ as_numpy : bool = False ,
387+ xp : ModuleType | None = None ,
388+ ** kwargs : Any ,
389+ ) -> Array | tuple [Array , ...]:
390+ """
391+ Lazily apply an eager elementwise function.
392+
393+ This is a variant of :func:`lazy_apply` which expects `func` to be elementwise, e.g.
394+ each output point must depend exclusively from the corresponding input point in each
395+ inputarray. This can result in faster execution on some backends.
396+
397+ Parameters
398+ ----------
399+ func : callable
400+ As in `lazy_apply`, but in addition it must be elementwise.
401+ *args : Array | int | float | complex | bool | None
402+ As in `lazy_apply`.
403+ dtype : DType | Sequence[DType | None], optional
404+ Output dtype or sequence of output dtypes, one for each output of `func`.
405+ dtype(s) must belong to the same array namespace as the input arrays.
406+ This also informs how many outputs the function has.
407+ Default: assume a single output and infer the result type(s) from
408+ the input arrays.
409+ as_numpy : bool, optional
410+ As in `lazy_apply`.
411+ xp : array_namespace, optional
412+ The standard-compatible namespace for `args`. Default: infer.
413+ **kwargs : Any, optional
414+ As in `lazy_apply`.
415+
416+ Returns
417+ -------
418+ Array | tuple[Array, ...]
419+ The result(s) of `func` applied to the input arrays, wrapped in the same
420+ array namespace as the inputs.
421+ If dtype is omitted or a single dtype, return a single array.
422+ Otherwise, return a tuple of arrays.
423+
424+ See Also
425+ --------
426+ lazy_apply : General version of this function.
427+ dask.array.map_blocks : Dask version of this function.
428+
429+ Notes
430+ -----
431+ Unlike in :func:`lazy_apply`, you can't define output shapes that aren't
432+ broadcasted from the input arrays.
433+
434+ Dask
435+ Unlike :func:`dask.array.map_blocks`, this function allows for multiple outputs.
436+
437+ Dask wrapping around other backends
438+ If ``as_numpy=False``, `func` will receive in input eager arrays of the meta
439+ namespace, as defined by the ``._meta`` attribute of the input Dask arrays. The
440+ outputs of `func` will be wrapped by the meta namespace, and then wrapped again
441+ by Dask.
442+
443+ All other backends
444+ This function is identical to :func:`lazy_apply`.
445+ """
446+ args_not_none = [arg for arg in args if arg is not None ]
447+ array_args = [arg for arg in args_not_none if not is_python_scalar (arg )]
448+ if not array_args :
449+ msg = "Must have at least one argument array"
450+ raise ValueError (msg )
451+ if xp is None :
452+ xp = array_namespace (* array_args )
453+
454+ # Normalize and validate dtype
455+ dtypes : list [DType ]
456+
457+ if isinstance (dtype , Sequence ):
458+ multi_output = True
459+ if None in dtype :
460+ rtype = xp .result_type (* args_not_none )
461+ dtypes = [d or rtype for d in dtype ]
462+ else :
463+ dtypes = list (dtype ) # pyright: ignore[reportUnknownArgumentType]
464+ else :
465+ multi_output = False
466+ dtypes = [dtype ]
467+ del dtype
468+
469+ if not is_dask_namespace (xp ):
470+ shape = broadcast_shapes (* (arg .shape for arg in array_args ))
471+ return lazy_apply ( # pyright: ignore[reportCallIssue]
472+ func , # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
473+ * args ,
474+ shape = [shape ] * len (dtypes ) if multi_output else shape , # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
475+ dtype = dtypes if multi_output else dtypes [0 ],
476+ as_numpy = as_numpy ,
477+ xp = xp ,
478+ ** kwargs ,
479+ )
480+
481+ # Use da.map_blocks.
482+ # We need to handle multiple outputs, which map_blocks can't.
483+
484+ metas : list [Array ] = [arg ._meta for arg in array_args ] # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue]
485+ meta_xp = array_namespace (* metas )
486+
487+ wrapped = _lazy_apply_wrapper (func , as_numpy , multi_output , meta_xp )
488+ wrapped = partial (wrapped , ** kwargs )
489+
490+ # Hack map_blocks to handle multiple outputs. This intermediate output has bugos
491+ # dtype and meta, but dask.array will never know as long as we always provide
492+ # explicit dtype and meta.
493+ temp = xp .map_blocks (wrapped , * args , dtype = dtypes [0 ], meta = metas [0 ])
494+ out = tuple (
495+ temp .map_blocks (operator .itemgetter (i ), dtype = dtype , meta = metas [0 ])
496+ for i , dtype in enumerate (dtypes )
497+ )
498+
499+ return out if multi_output else out [0 ]
0 commit comments