|
5 | 5 |
|
6 | 6 | import math |
7 | 7 | import warnings |
8 | | -from collections.abc import Sequence |
| 8 | +from collections.abc import Callable, Sequence |
9 | 9 | from types import ModuleType |
10 | | -from typing import cast |
| 10 | +from typing import cast, overload |
11 | 11 |
|
12 | 12 | from ._at import at |
13 | 13 | from ._utils import _compat, _helpers |
14 | | -from ._utils._compat import array_namespace, is_jax_array |
| 14 | +from ._utils._compat import array_namespace, is_array_api_obj, is_jax_array |
15 | 15 | from ._utils._helpers import asarrays |
16 | 16 | from ._utils._typing import Array |
17 | 17 |
|
18 | 18 | __all__ = [ |
| 19 | + "apply_where", |
19 | 20 | "atleast_nd", |
20 | 21 | "cov", |
21 | 22 | "create_diagonal", |
|
28 | 29 | ] |
29 | 30 |
|
30 | 31 |
|
| 32 | +@overload |
| 33 | +def apply_where( |
| 34 | + cond: Array, |
| 35 | + f1: Callable[..., Array], |
| 36 | + f2: Callable[..., Array], |
| 37 | + /, |
| 38 | + *args: Array, |
| 39 | + xp: ModuleType | None = None, |
| 40 | +): ... |
| 41 | + |
| 42 | + |
| 43 | +@overload |
| 44 | +def apply_where( |
| 45 | + cond: Array, |
| 46 | + f1: Callable[..., Array], |
| 47 | + /, |
| 48 | + *args: Array, |
| 49 | + fill_value: Array | int | float | complex | bool, |
| 50 | + xp: ModuleType | None = None, |
| 51 | +): ... |
| 52 | + |
| 53 | + |
| 54 | +def apply_where( |
| 55 | + cond: Array, |
| 56 | + f1: Callable[..., Array], |
| 57 | + f2_or_args0: Callable[..., Array] | Array, |
| 58 | + /, |
| 59 | + *args: Array, |
| 60 | + fill_value: Array | int | float | complex | bool | None = None, |
| 61 | + xp: ModuleType | None = None, |
| 62 | +): |
| 63 | + """Return elements chosen from two possibilities depending on a condition |
| 64 | +
|
| 65 | + Equivalent to ``f1(*args) if cond else f2(*args)`` performed elementwise. |
| 66 | +
|
| 67 | + Parameters |
| 68 | + ---------- |
| 69 | + cond : array |
| 70 | + The condition (expressed as a boolean array). |
| 71 | + f1 : callable |
| 72 | + Where `cond` is True, output will be ``f1(arr1[cond], arr2[cond], ...)``. |
| 73 | + f2 : callable, optional |
| 74 | + Where `cond` is False, output will be ``f1(arr1[cond], arr2[cond], ...)``. |
| 75 | + Mutually exclusive with `fill_value`. |
| 76 | + *args : one or more arrays |
| 77 | + Arguments to `f1` (and `f2`). Must be broadcastable with `cond`. |
| 78 | + fill_value : Array or scalar, optional |
| 79 | + If provided, value with which to fill output array where `cond` is |
| 80 | + not True. Mutually exclusive with `f2`. You must provide either one. |
| 81 | + xp : array_namespace, optional |
| 82 | + The standard-compatible namespace for `cond` and `args`. Default: infer. |
| 83 | +
|
| 84 | + Returns |
| 85 | + ------- |
| 86 | + out : array |
| 87 | + An array with elements from the output of `f1` where `cond` is True and either |
| 88 | + the output of `f2` or `fill_value` where `cond` is False. The returned array has |
| 89 | + data type determined by Type Promotion Rules between the output of `f` and |
| 90 | + either `fill_value` or the output of `f2`. |
| 91 | +
|
| 92 | + Notes |
| 93 | + ----- |
| 94 | + ``xp.where(cond, f1(*args), f2(*args))`` requires explicitly evaluating f1 even when |
| 95 | + `cond` is False, and `f2` when cond is True. This function evaluates each function |
| 96 | + only for their matching condition when the backend allows for it. |
| 97 | +
|
| 98 | + Examples |
| 99 | + -------- |
| 100 | + >>> a, b = xp.asarray([1, 2, 3, 4]), xp.asarray([5, 6, 7, 8]) |
| 101 | + >>> def f(a, b): |
| 102 | + ... return a * b |
| 103 | + >>> apply_where(a > 2, f, a, b, fill_value=xp.nan) |
| 104 | + array([ nan, nan, 21., 32.]) |
| 105 | +
|
| 106 | + """ |
| 107 | + # Parse and normalize arguments |
| 108 | + mutually_exc_msg = "Exactly one of `fill_value` or `f2` must be given." |
| 109 | + if is_array_api_obj(f2_or_args0): |
| 110 | + args = (cast(Array, f2_or_args0), *args) |
| 111 | + if fill_value is not None: |
| 112 | + raise TypeError(mutually_exc_msg) |
| 113 | + f2: Callable[..., Array] | None = None |
| 114 | + else: |
| 115 | + if not callable(f2): |
| 116 | + msg = "Third parameter must be either an Array or callable." |
| 117 | + raise ValueError(msg) |
| 118 | + f2 = f2_or_args0 |
| 119 | + if fill_value is None: |
| 120 | + raise TypeError(mutually_exc_msg) |
| 121 | + |
| 122 | + xp = array_namespace(cond, *args) if xp is None else xp |
| 123 | + |
| 124 | + if fill_value is not None and getattr(fill_value, "ndim", 0) != 0: |
| 125 | + msg = "`fill_value` must be a scalar." |
| 126 | + raise ValueError(msg) |
| 127 | + |
| 128 | + args = xp.broadcast_arrays(cond, *args) |
| 129 | + bool_dtype = xp.asarray([True]).dtype # numpy 1.xx doesn't have `bool` |
| 130 | + cond, args = xp.astype(args[0], bool_dtype, copy=False), args[1:] |
| 131 | + |
| 132 | + temp1 = f1(*(arr[cond] for arr in args)) |
| 133 | + |
| 134 | + if f2 is None: |
| 135 | + if is_array_api_obj(fill_value) or xp.__array_api_version__ >= "2024.12": |
| 136 | + dtype = xp.result_type(temp1.dtype, fill_value) |
| 137 | + else: |
| 138 | + # TODO: remove this branch when all backends support |
| 139 | + # Array API 2024.12 |
| 140 | + dtype = (xp.zeros((), dtype=temp1.dtype) * fill_value).dtype |
| 141 | + out = xp.full( |
| 142 | + cond.shape, dtype=dtype, fill_value=xp.asarray(fill_value, dtype=dtype) |
| 143 | + ) |
| 144 | + else: |
| 145 | + ncond = ~cond |
| 146 | + temp2 = xp.asarray(f2(*(arr[ncond] for arr in args))) |
| 147 | + dtype = xp.result_type(temp1, temp2) |
| 148 | + out = xp.empty(cond.shape, dtype=dtype) |
| 149 | + out = at(out, ncond).set(temp2) |
| 150 | + |
| 151 | + return at(out, cond).set(temp1) |
| 152 | + |
| 153 | + |
31 | 154 | def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array: |
32 | 155 | """ |
33 | 156 | Recursively expand the dimension of an array to at least `ndim`. |
|
0 commit comments