Skip to content

Commit 0e8706e

Browse files
committed
v2
1 parent 6884a34 commit 0e8706e

File tree

2 files changed

+153
-69
lines changed

2 files changed

+153
-69
lines changed

array_api_compat/common/_helpers.py

Lines changed: 152 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
"""
88
from __future__ import annotations
99

10+
import operator
1011
from typing import TYPE_CHECKING
1112

1213
if TYPE_CHECKING:
13-
from typing import Optional, Union, Any
14+
from typing import Callable, Optional, Union, Any
1415
from ._typing import Array, Device
1516

1617
import sys
@@ -811,9 +812,22 @@ def is_writeable_array(x):
811812
return False
812813
return True
813814

815+
def _parse_copy_param(x, copy: bool | None) -> bool:
816+
"""Preprocess and validate a copy parameter, in line with the same
817+
parameter in np.asarray(), np.astype(), etc.
818+
"""
819+
if copy is None:
820+
return not is_writeable_array(x)
821+
if copy is False:
822+
if not is_writeable_array(x):
823+
raise ValueError("Cannot avoid modifying parameter in place")
824+
elif copy is not True:
825+
raise ValueError(f"Invalid value for copy: {copy!r}")
826+
return copy
827+
814828
_undef = object()
815829

816-
def at(x, idx=_undef, /):
830+
class at:
817831
"""
818832
Update operations for read-only arrays.
819833
@@ -823,12 +837,22 @@ def at(x, idx=_undef, /):
823837
Keyword arguments (e.g. ``indices_are_sorted``) are passed to JAX and are
824838
quietly ignored for backends that don't support them.
825839
840+
Additionally, this introduces support for the `copy` keyword for all backends:
841+
842+
None
843+
x *may* be modified in place if it is possible and beneficial
844+
for performance. You should not use x after calling this function.
845+
True
846+
Ensure that the inputs are not modified. This is the default.
847+
False
848+
Raise ValueError if a copy cannot be avoided.
849+
826850
Examples
827851
--------
828852
Given either of these equivalent expressions::
829853
830-
x = at(x)[1].add(2)
831-
x = at(x, 1).add(2)
854+
x = at(x)[1].add(2, copy=None)
855+
x = at(x, 1).add(2, copy=None)
832856
833857
If x is a JAX array, they are the same as::
834858
@@ -845,16 +869,17 @@ def at(x, idx=_undef, /):
845869
846870
Warning
847871
-------
848-
You should always immediately overwrite the parameter array::
872+
When you use copy=None, you should always immediately overwrite
873+
the parameter array::
849874
850-
x = at(x, 0).set(2)
875+
x = at(x, 0).set(2, copy=None)
851876
852877
The anti-pattern below must be avoided, as it will result in different behaviour
853878
on read-only versus writeable arrays:
854879
855880
x = xp.asarray([0, 0, 0])
856-
y = at(x, 0).set(2)
857-
z = at(x, 1).set(3)
881+
y = at(x, 0).set(2, copy=None)
882+
z = at(x, 1).set(3, copy=None)
858883
859884
In the above example, y == [2, 0, 0] and z == [0, 3, 0] when x is read-only,
860885
whereas y == z == [2, 3, 0] when x is writeable!
@@ -863,18 +888,6 @@ def at(x, idx=_undef, /):
863888
--------
864889
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
865890
"""
866-
if is_jax_array(x):
867-
return x.at
868-
if is_numpy_array(x) and not x.flags.writeable:
869-
x = x.copy()
870-
return _InPlaceAt(x, idx)
871-
872-
class _InPlaceAt:
873-
"""Helper of at().
874-
875-
Trivially implement jax.numpy.ndarray.at for other backends.
876-
x is updated in place.
877-
"""
878891
__slots__ = ("x", "idx")
879892

880893
def __init__(self, x, idx=_undef):
@@ -890,7 +903,16 @@ def __getitem__(self, idx):
890903
self.idx = idx
891904
return self
892905

893-
def _check_args(self, mode="promise_in_bounds", **kwargs):
906+
def _common(self, at_op, y=_undef, mode: str = "promise_in_bounds", **kwargs):
907+
"""Validate kwargs and perform common prepocessing.
908+
909+
Returns
910+
-------
911+
If the operation can be resolved by at[],
912+
(return value, None)
913+
Otherwise,
914+
(None, preprocessed x)
915+
"""
894916
if self.idx is _undef:
895917
raise TypeError(
896918
"Index has not been set.\n"
@@ -900,74 +922,136 @@ def _check_args(self, mode="promise_in_bounds", **kwargs):
900922
" at(x)[idx].set(value)\n"
901923
"(same for all other methods)."
902924
)
903-
if mode != "promise_in_bounds":
925+
if mode != "promise_in_bounds" and not is_jax_array(self.x):
904926
xp = array_namespace(self.x)
905927
raise NotImplementedError(
906-
f"mode='{mode}' is not supported for backend {xp.__name__}"
928+
f"mode='{mode!r}' is not supported for backend {xp.__name__}"
929+
)
930+
931+
copy = _parse_copy_param(self.x, copy)
932+
933+
if copy and is_jax_array(self.x):
934+
# Use JAX's at[]
935+
at_ = self.x.at[self.idx]
936+
args = (y, ) if y is not _undef else ()
937+
return getattr(at_, at_op)(*args, mode=mode, **kwargs), None
938+
939+
# Emulate at[] behaviour for non-JAX arrays
940+
x = self.x.copy() if copy else self.x
941+
return None, x
942+
943+
def get(self, copy: bool | None = True, **kwargs):
944+
"""Return x[idx]. In addition to plain __getitem__, this allows ensuring
945+
that the output is (not) a copy and kwargs are passed to the backend."""
946+
# Special case when xp=numpy and idx is a fancy index
947+
# If copy is not False, avoid an unnecessary double copy.
948+
# if copy is forced to False, raise.
949+
if (
950+
is_numpy_array(self.x)
951+
and (
952+
isinstance(self.idx, (list, tuple))
953+
or (is_numpy_array(self.idx) and self.idx.dtype.kind in "biu")
907954
)
955+
):
956+
if copy is True:
957+
copy = None
958+
elif copy is False:
959+
raise ValueError(
960+
"Indexing a numpy array with a fancy index always "
961+
"results in a copy"
962+
)
963+
964+
res, x = self._common("get", copy=copy, **kwargs)
965+
if res is not None:
966+
return res
967+
return x[self.idx]
908968

909969
def set(self, y, /, **kwargs):
910-
self._check_args(**kwargs)
911-
self.x[self.idx] = y
912-
return self.x
970+
"""x[idx] = y"""
971+
res, x = self._common("set", y, **kwargs)
972+
if res is not None:
973+
return res
974+
x[self.idx] = y
975+
return x
976+
977+
def apply(self, ufunc, /, **kwargs):
978+
"""ufunc.at(x, idx)"""
979+
res, x = self._common("apply", ufunc, **kwargs)
980+
if res is not None:
981+
return res
982+
ufunc.at(x, self.idx)
983+
return x
984+
985+
def _iop(self, at_op: str, elwise_op: Callable[[Array, Array], Array], y: Array, **kwargs):
986+
"""x[idx] += y or equivalent in-place operation on a subset of x
987+
988+
which is the same as saying
989+
x[idx] = x[idx] + y
990+
Note that this is not the same as
991+
operator.iadd(x[idx], y)
992+
Consider for example when x is a numpy array and idx is a fancy index, which
993+
triggers a deep copy on __getitem__.
994+
"""
995+
res, x = self._common(at_op, y, **kwargs)
996+
if res is not None:
997+
return res
998+
x[self.idx] = elwise_op(x[self.idx], y)
999+
return x
9131000

9141001
def add(self, y, /, **kwargs):
915-
self._check_args(**kwargs)
916-
self.x[self.idx] += y
917-
return self.x
918-
1002+
"""x[idx] += y"""
1003+
return self._iop("add", operator.add, y, **kwargs)
1004+
9191005
def subtract(self, y, /, **kwargs):
920-
self._check_args(**kwargs)
921-
self.x[self.idx] -= y
922-
return self.x
1006+
"""x[idx] -= y"""
1007+
return self._iop("subtract", operator.sub, y, **kwargs)
9231008

9241009
def multiply(self, y, /, **kwargs):
925-
self._check_args(**kwargs)
926-
self.x[self.idx] *= y
927-
return self.x
1010+
"""x[idx] *= y"""
1011+
return self._iop("multiply", operator.mul, y, **kwargs)
9281012

9291013
def divide(self, y, /, **kwargs):
930-
self._check_args(**kwargs)
931-
self.x[self.idx] /= y
932-
return self.x
933-
1014+
"""x[idx] /= y"""
1015+
return self._iop("divide", operator.truediv, y, **kwargs)
1016+
9341017
def power(self, y, /, **kwargs):
935-
self._check_args(**kwargs)
936-
self.x[self.idx] **= y
937-
return self.x
1018+
"""x[idx] **= y"""
1019+
return self._iop("power", operator.pow, y, **kwargs)
9381020

9391021
def min(self, y, /, **kwargs):
940-
self._check_args(**kwargs)
941-
xp = array_namespace(self.x, y)
942-
self.x[self.idx] = xp.minimum(self.x[self.idx], y)
943-
return self.x
1022+
"""x[idx] = minimum(x[idx], y)"""
1023+
xp = array_namespace(self.x)
1024+
return self._iop("min", xp.minimum, y, **kwargs)
9441025

9451026
def max(self, y, /, **kwargs):
946-
self._check_args(**kwargs)
947-
xp = array_namespace(self.x, y)
948-
self.x[self.idx] = xp.maximum(self.x[self.idx], y)
949-
return self.x
950-
951-
def apply(self, ufunc, /, **kwargs):
952-
self._check_args(**kwargs)
953-
ufunc.at(self.x, self.idx)
954-
return self.x
955-
956-
def get(self, **kwargs):
957-
self._check_args(**kwargs)
958-
return self.x[self.idx]
959-
960-
def iwhere(condition, x, y, /):
961-
"""Variant of xp.where(condition, x, y) which may or may not update
962-
x in place, if it's possible and beneficial for performance.
1027+
"""x[idx] = maximum(x[idx], y)"""
1028+
xp = array_namespace(self.x)
1029+
return self._iop("max", xp.maximum, y, **kwargs)
1030+
1031+
def where(condition, x, y, /, copy: bool | None = True):
1032+
"""Return elements from x when condition is True and from y when
1033+
it is False.
1034+
1035+
This is a wrapper around xp.where that adds the copy parameter:
1036+
1037+
None
1038+
x *may* be modified in place if it is possible and beneficial
1039+
for performance. You should not use x after calling this function.
1040+
True
1041+
Ensure that the inputs are not modified.
1042+
This is the default, in line with np.where.
1043+
False
1044+
Raise ValueError if a copy cannot be avoided.
9631045
"""
1046+
copy = _parse_copy_param(x, copy)
9641047
xp = array_namespace(condition, x, y)
965-
if is_writeable_array(x):
1048+
if copy:
1049+
return xp.where(condition, x, y)
1050+
else:
9661051
condition, x, y = xp.broadcast_arrays(condition, x, y)
9671052
x[condition] = y[condition]
9681053
return x
969-
else:
970-
return xp.where(condition, x, y)
1054+
9711055

9721056
__all__ = [
9731057
"array_namespace",
@@ -993,7 +1077,7 @@ def iwhere(condition, x, y, /):
9931077
"size",
9941078
"to_device",
9951079
"at",
996-
"iwhere",
1080+
"where",
9971081
]
9981082

9991083
_all_ignore = ['sys', 'math', 'inspect', 'warnings']

docs/helper-functions.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ instead, which would be wrapped.
3737
.. autofunction:: to_device
3838
.. autofunction:: size
3939
.. autofunction:: at
40-
.. autofunction:: iwhere
40+
.. autofunction:: where
4141

4242
Inspection Helpers
4343
------------------

0 commit comments

Comments
 (0)