Skip to content

Commit 77c254c

Browse files
committed
WIP lazywhere
1 parent 03f0b3e commit 77c254c

File tree

5 files changed

+206
-9
lines changed

5 files changed

+206
-9
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
apply_where
910
at
1011
atleast_nd
1112
broadcast_shapes

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ._delegation import isclose, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
6+
apply_where,
67
atleast_nd,
78
broadcast_shapes,
89
cov,
@@ -19,6 +20,7 @@
1920
# pylint: disable=duplicate-code
2021
__all__ = [
2122
"__version__",
23+
"apply_where",
2224
"at",
2325
"atleast_nd",
2426
"broadcast_shapes",

src/array_api_extra/_lib/_funcs.py

Lines changed: 175 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,25 @@
55

66
import math
77
import warnings
8-
from collections.abc import Sequence
8+
from collections.abc import Callable, Sequence
9+
from functools import partial
910
from types import ModuleType
10-
from typing import cast
11+
from typing import cast, overload
1112

1213
from ._at import at
1314
from ._utils import _compat, _helpers
14-
from ._utils._compat import array_namespace, is_jax_array
15-
from ._utils._helpers import asarrays
16-
from ._utils._typing import Array
15+
from ._utils._compat import (
16+
array_namespace,
17+
is_array_api_obj,
18+
is_dask_namespace,
19+
is_jax_array,
20+
is_jax_namespace,
21+
)
22+
from ._utils._helpers import asarrays, get_meta
23+
from ._utils._typing import Array, DType
1724

1825
__all__ = [
26+
"apply_where",
1927
"atleast_nd",
2028
"broadcast_shapes",
2129
"cov",
@@ -29,6 +37,168 @@
2937
]
3038

3139

40+
@overload
def apply_where(  # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
    cond: Array,
    f1: Callable[..., Array],
    f2: Callable[..., Array],
    /,
    *args: Array,
    xp: ModuleType | None = None,
) -> Array: ...


@overload
def apply_where(  # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08
    cond: Array,
    f1: Callable[..., Array],
    /,
    *args: Array,
    fill_value: Array | int | float | complex | bool,
    xp: ModuleType | None = None,
) -> Array: ...


def apply_where(  # type: ignore[no-any-explicit,misc] # numpydoc ignore=PR01,PR02
    cond: Array,
    f1: Callable[..., Array],
    f2: Callable[..., Array] | Array,
    /,
    *args: Array,
    fill_value: Array | int | float | complex | bool | None = None,
    xp: ModuleType | None = None,
) -> Array:
    """
    Run one of two elementwise functions depending on a condition.

    Equivalent to ``f1(*args) if cond else fill_value`` performed elementwise
    when `fill_value` is defined, otherwise to ``f1(*args) if cond else f2(*args)``.

    Parameters
    ----------
    cond : array
        The condition, expressed as a boolean array.
    f1 : callable
        Where `cond` is True, output will be ``f1(arg0[cond], arg1[cond], ...)``.
    f2 : callable, optional
        Where `cond` is False, output will be ``f2(arg0[~cond], arg1[~cond], ...)``.
        Mutually exclusive with `fill_value`.
    *args : one or more arrays
        Arguments to `f1` (and `f2`). Must be broadcastable with `cond`.
    fill_value : Array or scalar, optional
        If provided, value with which to fill output array where `cond` is
        not True. Must be 0-dimensional. Mutually exclusive with `f2`.
        You must provide one or the other.
    xp : array_namespace, optional
        The standard-compatible namespace for `cond` and `args`. Default: infer.

    Returns
    -------
    Array
        An array with elements from the output of `f1` where `cond` is True and either
        the output of `f2` or `fill_value` where `cond` is False. The returned array has
        data type determined by type promotion rules between the output of `f1` and
        either `fill_value` or the output of `f2`.

    Raises
    ------
    TypeError
        If both or neither of `f2` and `fill_value` are given, or if no input
        array is given.
    ValueError
        If the third positional parameter is neither an array nor a callable,
        or if `fill_value` is not 0-dimensional.

    Notes
    -----
    ``xp.where(cond, f1(*args), f2(*args))`` requires explicitly evaluating `f1` even
    when `cond` is False, and `f2` when cond is True. This function evaluates each
    function only for their matching condition, if the backend allows for it.

    Examples
    --------
    >>> a = xp.asarray([5, 4, 3])
    >>> b = xp.asarray([0, 2, 2])
    >>> def f(a, b):
    ...     return a // b
    >>> apply_where(b != 0, f, a, b, fill_value=xp.nan)
    array([ nan,  2.,  1.])
    """
    # Parse and normalize arguments
    mutually_exc_msg = "Exactly one of `fill_value` or `f2` must be given."
    f2_: Callable[..., Array] | None  # type: ignore[no-any-explicit]
    if is_array_api_obj(f2):
        # `fill_value` call form: no `f2` was passed, so the third positional
        # slot actually holds the first input array. Put it back into `args`.
        args = (cast(Array, f2), *args)
        f2_ = None
        if fill_value is None:
            raise TypeError(mutually_exc_msg)
        if getattr(fill_value, "ndim", 0) != 0:
            msg = "`fill_value` must be a scalar."
            raise ValueError(msg)
    elif callable(f2):
        # `f2` call form: `fill_value` must not be given as well.
        f2_ = cast(Callable[..., Array], f2)  # type: ignore[no-any-explicit]
        if fill_value is not None:
            raise TypeError(mutually_exc_msg)
    else:
        msg = "Third parameter must be either an Array or callable."
        raise ValueError(msg)
    del f2
    if not args:
        msg = "Must give at least one input array."
        raise TypeError(msg)

    xp = array_namespace(cond, *args) if xp is None else xp

    # Determine output dtype by running the functions on 0-sized mock inputs
    metas = [get_meta(arg, xp=xp) for arg in args]
    temp1 = f1(*metas)
    if f2_ is None:
        if xp.__array_api_version__ >= "2024.12" or is_array_api_obj(fill_value):
            dtype = xp.result_type(temp1.dtype, fill_value)
        else:
            # TODO: remove this when all backends support Array API 2024.12
            # Promote against the Python scalar through a 0-d multiplication.
            dtype = (xp.empty((), dtype=temp1.dtype) * fill_value).dtype
    else:
        temp2 = f2_(*metas)
        dtype = xp.result_type(temp1, temp2)

    if is_dask_namespace(xp):
        # Dask does not support assignment by boolean mask
        meta_xp = array_namespace(get_meta(cond), *metas)
        # pass dtype to both da.map_blocks and _apply_where
        return xp.map_blocks(
            partial(_apply_where, dtype=dtype, xp=meta_xp),
            cond,
            f1,
            f2_,
            *args,
            fill_value=fill_value,
            dtype=dtype,
            meta=metas[0],
        )

    return _apply_where(cond, f1, f2_, *args, fill_value=fill_value, dtype=dtype, xp=xp)
170+
171+
172+
def _apply_where(  # type: ignore[no-any-explicit]  # numpydoc ignore=PR01,RT01
    cond: Array,
    f1: Callable[..., Array],
    f2: Callable[..., Array] | None,
    *args: Array,
    fill_value: Array | int | float | complex | bool | None,
    dtype: DType,
    xp: ModuleType,
) -> Array:
    """Helper of `apply_where`. On Dask, this runs on a single chunk."""
    if is_jax_namespace(xp):
        # jax.jit does not support assignment by boolean mask, so both
        # branches are evaluated on the full inputs and selected elementwise.
        when_true = f1(*args)
        when_false = fill_value if f2 is None else f2(*args)
        return xp.where(cond, when_true, when_false)

    device = _compat.device(cond)
    cond, *operands = xp.broadcast_arrays(cond, *args)  # pyright: ignore[reportAssignmentType]

    # Evaluate f1 only on the elements selected by the mask.
    picked = f1(*[operand[cond] for operand in operands])

    if f2 is not None:
        # Evaluate f2 only on the complement of the mask, then scatter its
        # results into an uninitialised output buffer.
        inverse = ~cond
        remainder = f2(*[operand[inverse] for operand in operands])
        out = xp.empty(cond.shape, dtype=dtype, device=device)
        out = at(out, inverse).set(remainder)
    else:
        # No second function: start from an output pre-filled with fill_value.
        out = xp.full(cond.shape, fill_value=fill_value, dtype=dtype, device=device)

    return at(out, cond).set(picked)
200+
201+
32202
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
33203
"""
34204
Recursively expand the dimension of an array to at least `ndim`.

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77
from typing import cast
88

99
from . import _compat
10-
from ._compat import is_array_api_obj, is_numpy_array
10+
from ._compat import array_namespace, is_array_api_obj, is_dask_array, is_numpy_array
1111
from ._typing import Array
1212

13-
__all__ = ["in1d", "mean"]
14-
1513

1614
def in1d(
1715
x1: Array,
@@ -175,3 +173,28 @@ def asarrays(
175173
xa, xb = xp.asarray(a), xp.asarray(b)
176174

177175
return (xb, xa) if swap else (xa, xb)
176+
177+
178+
def get_meta(x: Array, xp: ModuleType | None = None) -> Array:
    """
    Build an empty stand-in array that mimics `x`.

    Parameters
    ----------
    x : Array
        The array to mimic.
    xp : ModuleType, optional
        Array namespace used to create the stand-in. Inferred from `x` when
        omitted. Not used when `x` is a Dask array.

    Returns
    -------
    Array
        A size-0 array with the same number of dimensions, dtype and device
        as `x`, created in namespace `xp`.
        For a Dask array, its ``_meta`` attribute is returned instead; that
        object has the namespace of the backend wrapped by Dask.
    """
    if not is_dask_array(x):
        namespace = array_namespace(x) if xp is None else xp
        # One zero-length axis per dimension of x keeps ndim while holding no data.
        empty_shape = (0,) * x.ndim
        return namespace.empty(empty_shape, dtype=x.dtype, device=_compat.device(x))
    return x._meta  # pylint: disable=protected-access

src/array_api_extra/_lib/_utils/_typing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
# To be changed to a Protocol later (see data-apis/array-api#589)
66
Array = Any # type: ignore[no-any-explicit]
7+
DType = Any # type: ignore[no-any-explicit]
78
Device = Any # type: ignore[no-any-explicit]
89
Index = Any # type: ignore[no-any-explicit]
910

10-
__all__ = ["Array", "Device", "Index"]
11+
__all__ = ["Array", "DType", "Device", "Index"]

0 commit comments

Comments
 (0)