Skip to content

Commit 68ab328

Browse files
committed
WIP at() method
1 parent e0c92d3 commit 68ab328

File tree

12 files changed

+474
-17
lines changed

12 files changed

+474
-17
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
at
910
atleast_nd
1011
cov
1112
create_diagonal

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ run.source = ["array_api_extra"]
162162
report.exclude_also = [
163163
'\.\.\.',
164164
'if typing.TYPE_CHECKING:',
165+
'if TYPE_CHECKING:',
165166
]
166167

167168

@@ -233,6 +234,7 @@ ignore = [
233234
"PLR09", # Too many <...>
234235
"PLR2004", # Magic value used in comparison
235236
"ISC001", # Conflicts with formatter
237+
"PD008", # pandas-use-of-dot-at
236238
]
237239

238240
[tool.ruff.lint.per-file-ignores]

src/array_api_extra/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
22

3-
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
3+
from ._funcs import (
4+
at,
5+
atleast_nd,
6+
cov,
7+
create_diagonal,
8+
expand_dims,
9+
kron,
10+
setdiff1d,
11+
sinc,
12+
)
413

514
__version__ = "0.3.3.dev0"
615

716
# pylint: disable=duplicate-code
817
__all__ = [
918
"__version__",
19+
"at",
1020
"atleast_nd",
1121
"cov",
1222
"create_diagonal",

src/array_api_extra/_funcs.py

Lines changed: 284 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
22

3-
import typing
3+
import operator
44
import warnings
5+
from typing import TYPE_CHECKING, Any, Callable
56

6-
if typing.TYPE_CHECKING:
7+
if TYPE_CHECKING:
78
from ._lib._typing import Array, ModuleType
89

910
from ._lib import _utils
10-
from ._lib._compat import array_namespace
11+
from ._lib._compat import (
12+
array_namespace,
13+
is_array_api_obj,
14+
is_dask_array,
15+
is_writeable_array,
16+
)
1117

1218
__all__ = [
19+
"at",
1320
"atleast_nd",
1421
"cov",
1522
"create_diagonal",
@@ -548,3 +555,277 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
548555
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device),
549556
)
550557
return xp.sin(y) / y
558+
559+
560+
_undef = object()
561+
562+
563+
class at:
564+
"""
565+
Update operations for read-only arrays.
566+
567+
This implements ``jax.numpy.ndarray.at`` for all backends.
568+
569+
Parameters
570+
----------
571+
x : array
572+
Input array.
573+
idx : index, optional
574+
You may use two alternate syntaxes::
575+
576+
at(x, idx).set(value) # or get(), add(), etc.
577+
at(x)[idx].set(value)
578+
579+
copy : bool, optional
580+
True (default)
581+
Ensure that the inputs are not modified.
582+
False
583+
Ensure that the update operation writes back to the input.
584+
Raise ValueError if a copy cannot be avoided.
585+
None
586+
The array parameter *may* be modified in place if it is possible and
587+
beneficial for performance.
588+
You should not reuse it after calling this function.
589+
xp : array_namespace, optional
590+
The standard-compatible namespace for `x`. Default: infer
591+
592+
Additionally, if the backend supports an `at` method, any additional keyword
593+
arguments are passed to it verbatim; e.g. this allows passing
594+
``indices_are_sorted=True`` to JAX.
595+
596+
Returns
597+
-------
598+
Updated input array.
599+
600+
Examples
601+
--------
602+
Given either of these equivalent expressions::
603+
604+
x = at(x)[1].add(2, copy=None)
605+
x = at(x, 1).add(2, copy=None)
606+
607+
If x is a JAX array, they are the same as::
608+
609+
x = x.at[1].add(2)
610+
611+
If x is a read-only numpy array, they are the same as::
612+
613+
x = x.copy()
614+
x[1] += 2
615+
616+
Otherwise, they are the same as::
617+
618+
x[1] += 2
619+
620+
Warning
621+
-------
622+
When you use copy=None, you should always immediately overwrite
623+
the parameter array::
624+
625+
x = at(x, 0).set(2, copy=None)
626+
627+
The anti-pattern below must be avoided, as it will result in different behaviour
628+
on read-only versus writeable arrays::
629+
630+
x = xp.asarray([0, 0, 0])
631+
y = at(x, 0).set(2, copy=None)
632+
z = at(x, 1).set(3, copy=None)
633+
634+
In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
635+
when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable!
636+
637+
Warning
638+
-------
639+
The behaviour of update methods when the index is an array of integers which
640+
contains multiple occurrences of the same index is undefined;
641+
e.g. ``at(x, [0, 0]).set(2)``
642+
643+
Note
644+
----
645+
`sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet.
646+
647+
See Also
648+
--------
649+
`jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
650+
"""
651+
652+
x: Array
653+
idx: Any
654+
__slots__ = ("x", "idx")
655+
656+
def __init__(self, x: Array, idx: Any = _undef, /):
657+
self.x = x
658+
self.idx = idx
659+
660+
def __getitem__(self, idx: Any) -> Any:
661+
"""Allow for the alternate syntax ``at(x)[start:stop:step]``,
662+
which looks prettier than ``at(x, slice(start, stop, step))``
663+
and feels more intuitive coming from the JAX documentation.
664+
"""
665+
if self.idx is not _undef:
666+
msg = "Index has already been set"
667+
raise ValueError(msg)
668+
self.idx = idx
669+
return self
670+
671+
def _common(
672+
self,
673+
at_op: str,
674+
y: Array = _undef,
675+
/,
676+
copy: bool | None = True,
677+
xp: ModuleType | None = None,
678+
_is_update: bool = True,
679+
**kwargs: Any,
680+
) -> tuple[Any, None] | tuple[None, Array]:
681+
"""Perform common prepocessing.
682+
683+
Returns
684+
-------
685+
If the operation can be resolved by at[], (return value, None)
686+
Otherwise, (None, preprocessed x)
687+
"""
688+
if self.idx is _undef:
689+
msg = (
690+
"Index has not been set.\n"
691+
"Usage: either\n"
692+
" at(x, idx).set(value)\n"
693+
"or\n"
694+
" at(x)[idx].set(value)\n"
695+
"(same for all other methods)."
696+
)
697+
raise TypeError(msg)
698+
699+
x = self.x
700+
701+
if copy is True:
702+
writeable = None
703+
elif copy is False:
704+
writeable = is_writeable_array(x)
705+
if not writeable:
706+
msg = "Cannot modify parameter in place"
707+
raise ValueError(msg)
708+
elif copy is None:
709+
writeable = is_writeable_array(x)
710+
copy = _is_update and not writeable
711+
else:
712+
msg = f"Invalid value for copy: {copy!r}"
713+
raise ValueError(msg)
714+
715+
if copy:
716+
try:
717+
at_ = x.at
718+
except AttributeError:
719+
# Emulate at[] behaviour for non-JAX arrays
720+
# with a copy followed by an update
721+
if xp is None:
722+
xp = array_namespace(x)
723+
# Create writeable copy of read-only numpy array
724+
x = xp.asarray(x, copy=True)
725+
if writeable is False:
726+
# A copy of a read-only numpy array is writeable
727+
writeable = None
728+
else:
729+
# Use JAX's at[] or other library that with the same duck-type API
730+
args = (y,) if y is not _undef else ()
731+
return getattr(at_[self.idx], at_op)(*args, **kwargs), None
732+
733+
if _is_update:
734+
if writeable is None:
735+
writeable = is_writeable_array(x)
736+
if not writeable:
737+
# sparse crashes here
738+
msg = f"Array {x} has no `at` method and is read-only"
739+
raise ValueError(msg)
740+
741+
return None, x
742+
743+
def get(self, **kwargs: Any) -> Any:
744+
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
745+
that the output is either a copy or a view; it also allows passing
746+
keyword arguments to the backend.
747+
"""
748+
if kwargs.get("copy") is False:
749+
if is_array_api_obj(self.idx):
750+
# Boolean index. Note that the array API spec
751+
# https://data-apis.org/array-api/latest/API_specification/indexing.html
752+
# does not allow for list, tuple, and tuples of slices plus one or more
753+
# one-dimensional array indices, although many backends support them.
754+
# So this check will encounter a lot of false negatives in real life,
755+
# which can be caught by testing the user code vs. array-api-strict.
756+
msg = "get() with an array index always returns a copy"
757+
raise ValueError(msg)
758+
if is_dask_array(self.x):
759+
msg = "get() on Dask arrays always returns a copy"
760+
raise ValueError(msg)
761+
762+
res, x = self._common("get", _is_update=False, **kwargs)
763+
if res is not None:
764+
return res
765+
assert x is not None
766+
return x[self.idx]
767+
768+
def set(self, y: Array, /, **kwargs: Any) -> Array:
769+
"""Apply ``x[idx] = y`` and return the update array"""
770+
res, x = self._common("set", y, **kwargs)
771+
if res is not None:
772+
return res
773+
assert x is not None
774+
x[self.idx] = y
775+
return x
776+
777+
def _iop(
778+
self,
779+
at_op: str,
780+
elwise_op: Callable[[Array, Array], Array],
781+
y: Array,
782+
/,
783+
**kwargs: Any,
784+
) -> Array:
785+
"""x[idx] += y or equivalent in-place operation on a subset of x
786+
787+
which is the same as saying
788+
x[idx] = x[idx] + y
789+
Note that this is not the same as
790+
operator.iadd(x[idx], y)
791+
Consider for example when x is a numpy array and idx is a fancy index, which
792+
triggers a deep copy on __getitem__.
793+
"""
794+
res, x = self._common(at_op, y, **kwargs)
795+
if res is not None:
796+
return res
797+
assert x is not None
798+
x[self.idx] = elwise_op(x[self.idx], y)
799+
return x
800+
801+
def add(self, y: Array, /, **kwargs: Any) -> Array:
802+
"""Apply ``x[idx] += y`` and return the updated array"""
803+
return self._iop("add", operator.add, y, **kwargs)
804+
805+
def subtract(self, y: Array, /, **kwargs: Any) -> Array:
806+
"""Apply ``x[idx] -= y`` and return the updated array"""
807+
return self._iop("subtract", operator.sub, y, **kwargs)
808+
809+
def multiply(self, y: Array, /, **kwargs: Any) -> Array:
810+
"""Apply ``x[idx] *= y`` and return the updated array"""
811+
return self._iop("multiply", operator.mul, y, **kwargs)
812+
813+
def divide(self, y: Array, /, **kwargs: Any) -> Array:
814+
"""Apply ``x[idx] /= y`` and return the updated array"""
815+
return self._iop("divide", operator.truediv, y, **kwargs)
816+
817+
def power(self, y: Array, /, **kwargs: Any) -> Array:
818+
"""Apply ``x[idx] **= y`` and return the updated array"""
819+
return self._iop("power", operator.pow, y, **kwargs)
820+
821+
def min(self, y: Array, /, **kwargs: Any) -> Array:
822+
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
823+
xp = array_namespace(self.x)
824+
y = xp.asarray(y)
825+
return self._iop("min", xp.minimum, y, **kwargs)
826+
827+
def max(self, y: Array, /, **kwargs: Any) -> Array:
828+
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
829+
xp = array_namespace(self.x)
830+
y = xp.asarray(y)
831+
return self._iop("max", xp.maximum, y, **kwargs)

src/array_api_extra/_lib/_compat.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,23 @@
66
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
77
array_namespace, # pyright: ignore[reportUnknownVariableType]
88
device, # pyright: ignore[reportUnknownVariableType]
9+
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
10+
is_dask_array, # pyright: ignore[reportUnknownVariableType]
11+
is_writeable_array, # pyright: ignore[reportUnknownVariableType]
912
)
1013
except ImportError:
1114
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
1215
array_namespace, # pyright: ignore[reportUnknownVariableType]
1316
device,
17+
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
18+
is_dask_array, # pyright: ignore[reportUnknownVariableType]
19+
is_writeable_array, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
1420
)
1521

16-
__all__ = [
22+
__all__ = (
1723
"array_namespace",
1824
"device",
19-
]
25+
"is_array_api_obj",
26+
"is_dask_array",
27+
"is_writeable_array",
28+
)

src/array_api_extra/_lib/_compat.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@ def array_namespace(
1111
use_compat: bool | None = None,
1212
) -> ArrayModule: ...
1313
def device(x: Array, /) -> Device: ...
14+
def is_array_api_obj(x: object, /) -> bool: ...
15+
def is_dask_array(x: object, /) -> bool: ...
16+
def is_writeable_array(x: object, /) -> bool: ...

0 commit comments

Comments
 (0)