Skip to content

Commit cd0e8aa

Browse files
committed
WIP lazywhere
1 parent 27b0bf2 commit cd0e8aa

File tree

1 file changed

+126
-3
lines changed

1 file changed

+126
-3
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 126 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,18 @@
55

66
import math
77
import warnings
8-
from collections.abc import Sequence
8+
from collections.abc import Callable, Sequence
99
from types import ModuleType
10-
from typing import cast
10+
from typing import cast, overload
1111

1212
from ._at import at
1313
from ._utils import _compat, _helpers
14-
from ._utils._compat import array_namespace, is_jax_array
14+
from ._utils._compat import array_namespace, is_array_api_obj, is_jax_array
1515
from ._utils._helpers import asarrays
1616
from ._utils._typing import Array
1717

1818
__all__ = [
19+
"apply_where",
1920
"atleast_nd",
2021
"cov",
2122
"create_diagonal",
@@ -28,6 +29,128 @@
2829
]
2930

3031

32+
@overload
33+
def apply_where(
34+
cond: Array,
35+
f1: Callable[..., Array],
36+
f2: Callable[..., Array],
37+
/,
38+
*args: Array,
39+
xp: ModuleType | None = None,
40+
): ...
41+
42+
43+
@overload
44+
def apply_where(
45+
cond: Array,
46+
f1: Callable[..., Array],
47+
/,
48+
*args: Array,
49+
fill_value: Array | int | float | complex | bool,
50+
xp: ModuleType | None = None,
51+
): ...
52+
53+
54+
def apply_where(
55+
cond: Array,
56+
f1: Callable[..., Array],
57+
f2_or_args0: Callable[..., Array] | Array,
58+
/,
59+
*args: Array,
60+
fill_value: Array | int | float | complex | bool | None = None,
61+
xp: ModuleType | None = None,
62+
):
63+
"""Return elements chosen from two possibilities depending on a condition
64+
65+
Equivalent to ``f1(*args) if cond else f2(*args)`` performed elementwise.
66+
67+
Parameters
68+
----------
69+
cond : array
70+
The condition (expressed as a boolean array).
71+
f1 : callable
72+
Where `cond` is True, output will be ``f1(arr1[cond], arr2[cond], ...)``.
73+
f2 : callable, optional
74+
Where `cond` is False, output will be ``f1(arr1[cond], arr2[cond], ...)``.
75+
Mutually exclusive with `fill_value`.
76+
*args : one or more arrays
77+
Arguments to `f1` (and `f2`). Must be broadcastable with `cond`.
78+
fill_value : Array or scalar, optional
79+
If provided, value with which to fill output array where `cond` is
80+
not True. Mutually exclusive with `f2`. You must provide either one.
81+
xp : array_namespace, optional
82+
The standard-compatible namespace for `cond` and `args`. Default: infer.
83+
84+
Returns
85+
-------
86+
out : array
87+
An array with elements from the output of `f1` where `cond` is True and either
88+
the output of `f2` or `fill_value` where `cond` is False. The returned array has
89+
data type determined by Type Promotion Rules between the output of `f` and
90+
either `fill_value` or the output of `f2`.
91+
92+
Notes
93+
-----
94+
``xp.where(cond, f1(*args), f2(*args))`` requires explicitly evaluating f1 even when
95+
`cond` is False, and `f2` when cond is True. This function evaluates each function
96+
only for their matching condition when the backend allows for it.
97+
98+
Examples
99+
--------
100+
>>> a, b = xp.asarray([1, 2, 3, 4]), xp.asarray([5, 6, 7, 8])
101+
>>> def f(a, b):
102+
... return a * b
103+
>>> apply_where(a > 2, f, a, b, fill_value=xp.nan)
104+
array([ nan, nan, 21., 32.])
105+
106+
"""
107+
# Parse and normalize arguments
108+
mutually_exc_msg = "Exactly one of `fill_value` or `f2` must be given."
109+
if is_array_api_obj(f2_or_args0):
110+
args = (cast(Array, f2_or_args0), *args)
111+
if fill_value is not None:
112+
raise TypeError(mutually_exc_msg)
113+
f2: Callable[..., Array] | None = None
114+
else:
115+
if not callable(f2):
116+
msg = "Third parameter must be either an Array or callable."
117+
raise ValueError(msg)
118+
f2 = f2_or_args0
119+
if fill_value is None:
120+
raise TypeError(mutually_exc_msg)
121+
122+
xp = array_namespace(cond, *args) if xp is None else xp
123+
124+
if fill_value is not None and getattr(fill_value, "ndim", 0) != 0:
125+
msg = "`fill_value` must be a scalar."
126+
raise ValueError(msg)
127+
128+
args = xp.broadcast_arrays(cond, *args)
129+
bool_dtype = xp.asarray([True]).dtype # numpy 1.xx doesn't have `bool`
130+
cond, args = xp.astype(args[0], bool_dtype, copy=False), args[1:]
131+
132+
temp1 = f1(*(arr[cond] for arr in args))
133+
134+
if f2 is None:
135+
if is_array_api_obj(fill_value) or xp.__array_api_version__ >= "2024.12":
136+
dtype = xp.result_type(temp1.dtype, fill_value)
137+
else:
138+
# TODO: remove this branch when all backends support
139+
# Array API 2024.12
140+
dtype = (xp.zeros((), dtype=temp1.dtype) * fill_value).dtype
141+
out = xp.full(
142+
cond.shape, dtype=dtype, fill_value=xp.asarray(fill_value, dtype=dtype)
143+
)
144+
else:
145+
ncond = ~cond
146+
temp2 = xp.asarray(f2(*(arr[ncond] for arr in args)))
147+
dtype = xp.result_type(temp1, temp2)
148+
out = xp.empty(cond.shape, dtype=dtype)
149+
out = at(out, ncond).set(temp2)
150+
151+
return at(out, cond).set(temp1)
152+
153+
31154
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
32155
"""
33156
Recursively expand the dimension of an array to at least `ndim`.

0 commit comments

Comments
 (0)