33from  __future__ import  annotations 
44
55import  math 
6+ import  operator 
67from  collections .abc  import  Callable , Sequence 
78from  functools  import  partial , wraps 
89from  types  import  ModuleType 
9- from  typing  import  TYPE_CHECKING , Any , ParamSpec ,  TypeAlias , cast , overload 
10+ from  typing  import  TYPE_CHECKING , Any , TypeAlias , cast , overload 
1011
1112from  ._funcs  import  broadcast_shapes 
1213from  ._utils  import  _compat 
2728    # Sphinx hack 
2829    NumPyObject  =  Any 
2930
30- P  =  ParamSpec ("P" )
31- 
3231
3332@overload  
34- def  lazy_apply (  # type: ignore[decorated -any, valid-type ] 
35-     func : Callable [P , Array  |  ArrayLike ],
33+ def  lazy_apply (  # type: ignore[explicit -any,decorated-any ] 
34+     func : Callable [... , Array  |  ArrayLike ],
3635    * args : Array  |  complex  |  None ,
3736    shape : tuple [int  |  None , ...] |  None  =  None ,
3837    dtype : DType  |  None  =  None ,
3938    as_numpy : bool  =  False ,
4039    xp : ModuleType  |  None  =  None ,
41-     ** kwargs : P . kwargs ,   # pyright: ignore[reportGeneralTypeIssues] 
40+     ** kwargs : Any , 
4241) ->  Array : ...  # numpydoc ignore=GL08 
4342
4443
4544@overload  
46- def  lazy_apply (  # type: ignore[decorated -any, valid-type ] 
47-     func : Callable [P , Sequence [Array  |  ArrayLike ]],
45+ def  lazy_apply (  # type: ignore[explicit -any,decorated-any ] 
46+     func : Callable [... , Sequence [Array  |  ArrayLike ]],
4847    * args : Array  |  complex  |  None ,
4948    shape : Sequence [tuple [int  |  None , ...]],
5049    dtype : Sequence [DType ] |  None  =  None ,
5150    as_numpy : bool  =  False ,
5251    xp : ModuleType  |  None  =  None ,
53-     ** kwargs : P . kwargs ,   # pyright: ignore[reportGeneralTypeIssues] 
52+     ** kwargs : Any , 
5453) ->  tuple [Array , ...]: ...  # numpydoc ignore=GL08 
5554
5655
57- def  lazy_apply (  # type: ignore[valid-type ]  # numpydoc ignore=GL07,SA04 
58-     func : Callable [P , Array  |  ArrayLike  |  Sequence [Array  |  ArrayLike ]],
56+ def  lazy_apply (  # type: ignore[explicit-any ]  # numpydoc ignore=GL07,SA04 
57+     func : Callable [... , Array  |  ArrayLike  |  Sequence [Array  |  ArrayLike ]],
5958    * args : Array  |  complex  |  None ,
6059    shape : tuple [int  |  None , ...] |  Sequence [tuple [int  |  None , ...]] |  None  =  None ,
6160    dtype : DType  |  Sequence [DType ] |  None  =  None ,
6261    as_numpy : bool  =  False ,
6362    xp : ModuleType  |  None  =  None ,
64-     ** kwargs : P . kwargs ,   # pyright: ignore[reportGeneralTypeIssues] 
63+     ** kwargs : Any , 
6564) ->  Array  |  tuple [Array , ...]:
6665    """ 
6766    Lazily apply an eager function. 
@@ -162,10 +161,11 @@ def lazy_apply(  # type: ignore[valid-type]  # numpydoc ignore=GL07,SA04
162161        The outputs will also be returned as a single chunk and you should consider 
163162        rechunking them into smaller chunks afterwards. 
164163
165-         If you want to distribute the calculation across multiple workers, you 
166-         should use :func:`dask.array.map_blocks`, :func:`dask.array.map_overlap`, 
167-         :func:`dask.array.blockwise`, or a native Dask wrapper instead of 
168-         `lazy_apply`. 
164+         If you want to distribute the calculation across multiple workers and your 
165+         function is elementwise, you should use :func:`lazy_apply_elementwise` instead. 
166+         If the function is not elementwise, you should consider writing an ad-hoc 
167+         variant for Dask using primitives like :func:`dask.array.blockwise`, 
168+         :func:`dask.array.map_overlap`, or a native Dask algorithm. 
169169
170170    Dask wrapping around other backends 
171171        If ``as_numpy=False``, `func` will receive in input eager arrays of the meta 
@@ -186,9 +186,9 @@ def lazy_apply(  # type: ignore[valid-type]  # numpydoc ignore=GL07,SA04
186186
187187    See Also 
188188    -------- 
189+     lazy_apply_elementwise 
189190    jax.transfer_guard 
190191    jax.pure_callback 
191-     dask.array.map_blocks 
192192    dask.array.map_overlap 
193193    dask.array.blockwise 
194194    """ 
@@ -240,7 +240,7 @@ def lazy_apply(  # type: ignore[valid-type]  # numpydoc ignore=GL07,SA04
240240    if  is_dask_namespace (xp ):
241241        import  dask 
242242
243-         metas : list [Array ] =  [arg ._meta  for  arg  in  array_args ]  # pylint: disable=protected-access     # pyright: ignore[reportAttributeAccessIssue] 
243+         metas : list [Array ] =  [arg ._meta  for  arg  in  array_args ]  # type: ignore[attr-defined] #  pylint: disable=protected-access  # pyright: ignore[reportAttributeAccessIssue] 
244244        meta_xp  =  array_namespace (* metas )
245245
246246        wrapped  =  dask .delayed (  # type: ignore[attr-defined]  # pyright: ignore[reportPrivateImportUsage] 
@@ -355,3 +355,145 @@ def wrapper(  # type: ignore[decorated-any,explicit-any]
355355        return  (xp .asarray (out , device = device ),)
356356
357357    return  wrapper 
358+ 
359+ 
360+ @overload  
361+ def  lazy_apply_elementwise (  # type: ignore[explicit-any,decorated-any] 
362+     func : Callable [..., Array  |  ArrayLike ],
363+     * args : Array  |  complex  |  None ,
364+     dtype : DType  |  None  =  None ,
365+     as_numpy : bool  =  False ,
366+     xp : ModuleType  |  None  =  None ,
367+     ** kwargs : Any ,
368+ ) ->  Array : ...  # numpydoc ignore=GL08 
369+ 
370+ 
371+ @overload  
372+ def  lazy_apply_elementwise (  # type: ignore[explicit-any,decorated-any] 
373+     func : Callable [..., Sequence [Array  |  ArrayLike ]],
374+     * args : Array  |  complex  |  None ,
375+     dtype : Sequence [DType  |  None ],
376+     as_numpy : bool  =  False ,
377+     xp : ModuleType  |  None  =  None ,
378+     ** kwargs : Any ,
379+ ) ->  tuple [Array , ...]: ...  # numpydoc ignore=GL08 
380+ 
381+ 
382+ def  lazy_apply_elementwise (  # type: ignore[explicit-any] 
383+     func : Callable [..., Array  |  ArrayLike  |  Sequence [Array  |  ArrayLike ]],
384+     * args : Array  |  complex  |  None ,
385+     dtype : DType  |  Sequence [DType  |  None ] |  None  =  None ,
386+     as_numpy : bool  =  False ,
387+     xp : ModuleType  |  None  =  None ,
388+     ** kwargs : Any ,
389+ ) ->  Array  |  tuple [Array , ...]:
390+     """ 
391+     Lazily apply an eager elementwise function. 
392+ 
393+     This is a variant of :func:`lazy_apply` which expects `func` to be elementwise, e.g. 
394+     each output point must depend exclusively from the corresponding input point in each 
395+     inputarray. This can result in faster execution on some backends. 
396+ 
397+     Parameters 
398+     ---------- 
399+     func : callable 
400+         As in `lazy_apply`, but in addition it must be elementwise. 
401+     *args : Array | int | float | complex | bool | None 
402+         As in `lazy_apply`. 
403+     dtype : DType | Sequence[DType | None], optional 
404+         Output dtype or sequence of output dtypes, one for each output of `func`. 
405+         dtype(s) must belong to the same array namespace as the input arrays. 
406+         This also informs how many outputs the function has. 
407+         Default: assume a single output and infer the result type(s) from 
408+         the input arrays. 
409+     as_numpy : bool, optional 
410+         As in `lazy_apply`. 
411+     xp : array_namespace, optional 
412+         The standard-compatible namespace for `args`. Default: infer. 
413+     **kwargs : Any, optional 
414+         As in `lazy_apply`. 
415+ 
416+     Returns 
417+     ------- 
418+     Array | tuple[Array, ...] 
419+         The result(s) of `func` applied to the input arrays, wrapped in the same 
420+         array namespace as the inputs. 
421+         If dtype is omitted or a single dtype, return a single array. 
422+         Otherwise, return a tuple of arrays. 
423+ 
424+     See Also 
425+     -------- 
426+     lazy_apply : General version of this function. 
427+     dask.array.map_blocks : Dask version of this function. 
428+ 
429+     Notes 
430+     ----- 
431+     Unlike in :func:`lazy_apply`, you can't define output shapes that aren't 
432+     broadcasted from the input arrays. 
433+ 
434+     Dask 
435+         Unlike :func:`dask.array.map_blocks`, this function allows for multiple outputs. 
436+ 
437+     Dask wrapping around other backends 
438+         If ``as_numpy=False``, `func` will receive in input eager arrays of the meta 
439+         namespace, as defined by the ``._meta`` attribute of the input Dask arrays. The 
440+         outputs of `func` will be wrapped by the meta namespace, and then wrapped again 
441+         by Dask. 
442+ 
443+     All other backends 
444+         This function is identical to :func:`lazy_apply`. 
445+     """ 
446+     args_not_none  =  [arg  for  arg  in  args  if  arg  is  not   None ]
447+     array_args  =  [arg  for  arg  in  args_not_none  if  not  is_python_scalar (arg )]
448+     if  not  array_args :
449+         msg  =  "Must have at least one argument array" 
450+         raise  ValueError (msg )
451+     if  xp  is  None :
452+         xp  =  array_namespace (* array_args )
453+ 
454+     # Normalize and validate dtype 
455+     dtypes : list [DType ]
456+ 
457+     if  isinstance (dtype , Sequence ):
458+         multi_output  =  True 
459+         if  None  in  dtype :
460+             rtype  =  xp .result_type (* args_not_none )
461+             dtypes  =  [d  or  rtype  for  d  in  dtype ]
462+         else :
463+             dtypes  =  list (dtype )  # pyright: ignore[reportUnknownArgumentType] 
464+     else :
465+         multi_output  =  False 
466+         dtypes  =  [dtype ]
467+     del  dtype 
468+ 
469+     if  not  is_dask_namespace (xp ):
470+         shape  =  broadcast_shapes (* (arg .shape  for  arg  in  array_args ))
471+         return  lazy_apply (  # pyright: ignore[reportCallIssue] 
472+             func ,  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType] 
473+             * args ,
474+             shape = [shape ] *  len (dtypes ) if  multi_output  else  shape ,  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType] 
475+             dtype = dtypes  if  multi_output  else  dtypes [0 ],
476+             as_numpy = as_numpy ,
477+             xp = xp ,
478+             ** kwargs ,
479+         )
480+ 
481+     # Use da.map_blocks. 
482+     # We need to handle multiple outputs, which map_blocks can't. 
483+ 
484+     metas : list [Array ] =  [arg ._meta  for  arg  in  array_args ]  # type: ignore[attr-defined]  # pylint: disable=protected-access  # pyright: ignore[reportAttributeAccessIssue] 
485+     meta_xp  =  array_namespace (* metas )
486+ 
487+     wrapped  =  _lazy_apply_wrapper (func , as_numpy , multi_output , meta_xp )
488+     wrapped  =  partial (wrapped , ** kwargs )
489+ 
490+     # Hack map_blocks to handle multiple outputs. This intermediate output has bugos 
491+     # dtype and meta, but dask.array will never know as long as we always provide 
492+     # explicit dtype and meta. 
493+     temp  =  xp .map_blocks (wrapped , * args , dtype = dtypes [0 ], meta = metas [0 ])
494+     out  =  tuple (
495+         temp .map_blocks (operator .itemgetter (i ), dtype = dtype , meta = metas [0 ])
496+         for  i , dtype  in  enumerate (dtypes )
497+     )
498+ 
499+     return  out  if  multi_output  else  out [0 ]
0 commit comments