@@ -90,7 +90,8 @@ def apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02
9090 Argument(s) to `f1` (and `f2`). Must be broadcastable with `cond`.
9191 fill_value : Array or scalar, optional
9292 If provided, value with which to fill output array where `cond` is False.
93- It does not need to be scalar.
93+ It does not need to be scalar; it needs however to be broadcastable with
94+ `cond` and `args`.
9495 Mutually exclusive with `f2`. You must provide one or the other.
9596 xp : array_namespace, optional
9697 The standard-compatible namespace for `cond` and `args`. Default: infer.
@@ -147,7 +148,8 @@ def apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02
147148 cond , * args = xp .broadcast_arrays (cond , * args )
148149
149150 if is_dask_namespace (xp ):
150- meta_xp = meta_namespace (cond , fill_value , * args , xp = xp )
151+ meta_xp = meta_namespace (cond , * args , fill_value , xp = xp )
152+ # map_blocks doesn't descend into tuples of Arrays
151153 return xp .map_blocks (_apply_where , cond , f1 , f2 , fill_value , * args , xp = meta_xp )
152154 return _apply_where (cond , f1 , f2 , fill_value , * args , xp = xp )
153155
@@ -166,21 +168,20 @@ def _apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
166168 # jax.jit does not support assignment by boolean mask
167169 return xp .where (cond , f1 (* args ), f2 (* args ) if f2 is not None else fill_value )
168170
169- device = _compat .device (cond )
170171 temp1 = f1 (* (arr [cond ] for arr in args ))
171172
172173 if f2 is None :
173174 # TODO remove asarrays once all backends support Array API 2024.12
174175 dtype = xp .result_type (* asarrays (temp1 , fill_value , xp = xp ))
175176 if getattr (fill_value , "ndim" , 0 ):
176- fill_value = xp .astype (fill_value , dtype )
177- return at ( fill_value , cond ). set ( temp1 , copy = True )
178- out = xp .full (cond . shape , fill_value = fill_value , dtype = dtype , device = device )
177+ out = xp .astype (fill_value , dtype , copy = True )
178+ else :
179+ out = xp .full_like (cond , dtype = dtype , fill_value = fill_value )
179180 else :
180181 ncond = ~ cond
181182 temp2 = f2 (* (arr [ncond ] for arr in args ))
182183 dtype = xp .result_type (temp1 , temp2 )
183- out = xp .empty (cond . shape , dtype = dtype , device = device )
184+ out = xp .empty_like (cond , dtype = dtype )
184185 out = at (out , ncond ).set (temp2 )
185186
186187 return at (out , cond ).set (temp1 )
0 commit comments