1515from ._utils ._compat import (
1616 array_namespace ,
1717 is_array_api_obj ,
18+ is_dask_array ,
1819 is_dask_namespace ,
1920 is_jax_array ,
2021 is_jax_namespace ,
2122)
22- from ._utils ._helpers import asarrays , get_meta
23+ from ._utils ._helpers import asarrays
2324from ._utils ._typing import Array , DType
2425
2526__all__ = [
@@ -138,35 +139,34 @@ def apply_where( # type: ignore[no-any-explicit,misc] # numpydoc ignore=PR01,PR
138139
139140 xp = array_namespace (cond , * args ) if xp is None else xp
140141
142+ if not is_dask_namespace (xp ):
143+ return _apply_where (
144+ cond , f1 , f2_ , * args , fill_value = fill_value , dtype = None , xp = xp
145+ )
146+
147+ # Dask-specific code from here onwards
148+ metas = [arg ._meta for arg in args ] # pylint: disable=protected-access
149+ meta_xp = array_namespace (cond ._meta , * metas ) # pylint: disable=protected-access
141150 # Determine output dtype
142- metas = [get_meta (arg , xp = xp ) for arg in args ]
143- temp1 = f1 (* metas )
144- if f2_ is None :
145- if xp .__array_api_version__ >= "2024.12" or is_array_api_obj (fill_value ):
146- dtype = xp .result_type (temp1 .dtype , fill_value )
147- else :
148- # TODO: remove this when all backends support Array API 2024.12
149- dtype = (xp .empty ((), dtype = temp1 .dtype ) * fill_value ).dtype
151+ if f2_ is not None :
152+ dtype = meta_xp .result_type (f1 (* metas ), f2_ (* metas ))
153+ elif is_dask_array (fill_value ):
154+ dtype = meta_xp .result_type (f1 (* metas ), cast (Array , fill_value )._meta ) # pylint: disable=protected-access
150155 else :
151- temp2 = f2_ ( * metas )
152- dtype = xp .result_type (temp1 , temp2 )
156+ # TODO remove asarrays once all backends support Array API 2024.12
157+ dtype = meta_xp .result_type (* asarrays ( f1 ( * metas ), fill_value , xp = meta_xp ) )
153158
154- if is_dask_namespace (xp ):
155- # Dask does not support assignment by boolean mask
156- meta_xp = array_namespace (get_meta (cond ), * metas )
159+ return xp .map_blocks (
157160 # pass dtype to both da.map_blocks and _apply_where
158- return xp .map_blocks (
159- partial (_apply_where , dtype = dtype , xp = meta_xp ),
160- cond ,
161- f1 ,
162- f2_ ,
163- * args ,
164- fill_value = fill_value ,
165- dtype = dtype ,
166- meta = metas [0 ],
167- )
168-
169- return _apply_where (cond , f1 , f2_ , * args , fill_value = fill_value , dtype = dtype , xp = xp )
161+ partial (_apply_where , dtype = dtype , xp = meta_xp ),
162+ cond ,
163+ f1 ,
164+ f2_ ,
165+ * args ,
166+ fill_value = fill_value ,
167+ dtype = dtype ,
168+ meta = metas [0 ],
169+ )
170170
171171
172172def _apply_where ( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
@@ -175,7 +175,7 @@ def _apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
175175 f2 : Callable [..., Array ] | None ,
176176 * args : Array ,
177177 fill_value : Array | int | float | complex | bool | None ,
178- dtype : DType ,
178+ dtype : DType | None ,
179179 xp : ModuleType ,
180180) -> Array :
181181 """Helper of `apply_where`. On Dask, this runs on a single chunk."""
@@ -189,10 +189,15 @@ def _apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
189189 temp1 = f1 (* (arr [cond ] for arr in args ))
190190
191191 if f2 is None :
192+ if dtype is None :
193+ # TODO remove asarrays once all backends support Array API 2024.12
194+ dtype = xp .result_type (* asarrays (temp1 , fill_value , xp = xp ))
192195 out = xp .full (cond .shape , fill_value = fill_value , dtype = dtype , device = device )
193196 else :
194197 ncond = ~ cond
195198 temp2 = f2 (* (arr [ncond ] for arr in args ))
199+ if dtype is None :
200+ dtype = xp .result_type (temp1 , temp2 )
196201 out = xp .empty (cond .shape , dtype = dtype , device = device )
197202 out = at (out , ncond ).set (temp2 )
198203
0 commit comments