Skip to content

Commit 3e669a2

Browse files
committed
Rework
1 parent 77c254c commit 3e669a2

File tree

2 files changed

+35
-53
lines changed

2 files changed

+35
-53
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
from ._utils._compat import (
1616
array_namespace,
1717
is_array_api_obj,
18+
is_dask_array,
1819
is_dask_namespace,
1920
is_jax_array,
2021
is_jax_namespace,
2122
)
22-
from ._utils._helpers import asarrays, get_meta
23+
from ._utils._helpers import asarrays
2324
from ._utils._typing import Array, DType
2425

2526
__all__ = [
@@ -138,35 +139,34 @@ def apply_where( # type: ignore[no-any-explicit,misc] # numpydoc ignore=PR01,PR
138139

139140
xp = array_namespace(cond, *args) if xp is None else xp
140141

142+
if not is_dask_namespace(xp):
143+
return _apply_where(
144+
cond, f1, f2_, *args, fill_value=fill_value, dtype=None, xp=xp
145+
)
146+
147+
# Dask-specific code from here onwards
148+
metas = [arg._meta for arg in args] # pylint: disable=protected-access
149+
meta_xp = array_namespace(cond._meta, *metas) # pylint: disable=protected-access
141150
# Determine output dtype
142-
metas = [get_meta(arg, xp=xp) for arg in args]
143-
temp1 = f1(*metas)
144-
if f2_ is None:
145-
if xp.__array_api_version__ >= "2024.12" or is_array_api_obj(fill_value):
146-
dtype = xp.result_type(temp1.dtype, fill_value)
147-
else:
148-
# TODO: remove this when all backends support Array API 2024.12
149-
dtype = (xp.empty((), dtype=temp1.dtype) * fill_value).dtype
151+
if f2_ is not None:
152+
dtype = meta_xp.result_type(f1(*metas), f2_(*metas))
153+
elif is_dask_array(fill_value):
154+
dtype = meta_xp.result_type(f1(*metas), cast(Array, fill_value)._meta) # pylint: disable=protected-access
150155
else:
151-
temp2 = f2_(*metas)
152-
dtype = xp.result_type(temp1, temp2)
156+
# TODO remove asarrays once all backends support Array API 2024.12
157+
dtype = meta_xp.result_type(*asarrays(f1(*metas), fill_value, xp=meta_xp))
153158

154-
if is_dask_namespace(xp):
155-
# Dask does not support assignment by boolean mask
156-
meta_xp = array_namespace(get_meta(cond), *metas)
159+
return xp.map_blocks(
157160
# pass dtype to both da.map_blocks and _apply_where
158-
return xp.map_blocks(
159-
partial(_apply_where, dtype=dtype, xp=meta_xp),
160-
cond,
161-
f1,
162-
f2_,
163-
*args,
164-
fill_value=fill_value,
165-
dtype=dtype,
166-
meta=metas[0],
167-
)
168-
169-
return _apply_where(cond, f1, f2_, *args, fill_value=fill_value, dtype=dtype, xp=xp)
161+
partial(_apply_where, dtype=dtype, xp=meta_xp),
162+
cond,
163+
f1,
164+
f2_,
165+
*args,
166+
fill_value=fill_value,
167+
dtype=dtype,
168+
meta=metas[0],
169+
)
170170

171171

172172
def _apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
@@ -175,7 +175,7 @@ def _apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
175175
f2: Callable[..., Array] | None,
176176
*args: Array,
177177
fill_value: Array | int | float | complex | bool | None,
178-
dtype: DType,
178+
dtype: DType | None,
179179
xp: ModuleType,
180180
) -> Array:
181181
"""Helper of `apply_where`. On Dask, this runs on a single chunk."""
@@ -189,10 +189,15 @@ def _apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
189189
temp1 = f1(*(arr[cond] for arr in args))
190190

191191
if f2 is None:
192+
if dtype is None:
193+
# TODO remove asarrays once all backends support Array API 2024.12
194+
dtype = xp.result_type(*asarrays(temp1, fill_value, xp=xp))
192195
out = xp.full(cond.shape, fill_value=fill_value, dtype=dtype, device=device)
193196
else:
194197
ncond = ~cond
195198
temp2 = f2(*(arr[ncond] for arr in args))
199+
if dtype is None:
200+
dtype = xp.result_type(temp1, temp2)
196201
out = xp.empty(cond.shape, dtype=dtype, device=device)
197202
out = at(out, ncond).set(temp2)
198203

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
from typing import cast
88

99
from . import _compat
10-
from ._compat import array_namespace, is_array_api_obj, is_dask_array, is_numpy_array
10+
from ._compat import is_array_api_obj, is_numpy_array
1111
from ._typing import Array
1212

13+
__all__ = ["asarrays", "in1d", "is_python_scalar", "mean"]
14+
1315

1416
def in1d(
1517
x1: Array,
@@ -173,28 +175,3 @@ def asarrays(
173175
xa, xb = xp.asarray(a), xp.asarray(b)
174176

175177
return (xb, xa) if swap else (xa, xb)
176-
177-
178-
def get_meta(x: Array, xp: ModuleType | None = None) -> Array:
179-
"""
180-
Return a 0-sized dummy array that mocks `x`.
181-
182-
Parameters
183-
----------
184-
x : Array
185-
The array to mock.
186-
xp : ModuleType, optional
187-
The array namespace to use. If None, it is inferred from `x`.
188-
189-
Returns
190-
-------
191-
Array
192-
Array with size 0 with the same same namespace, dimensionality,
193-
dtype and device as `x`.
194-
On Dask, return instead the meta array of `x`, which has the
195-
namespace of the wrapped backend.
196-
"""
197-
if is_dask_array(x):
198-
return x._meta # pylint: disable=protected-access
199-
xp = array_namespace(x) if xp is None else xp
200-
return xp.empty((0,) * x.ndim, dtype=x.dtype, device=_compat.device(x))

0 commit comments

Comments
 (0)