Skip to content

Commit 1f30d35

Browse files
committed
WIP at
1 parent 7875ed6 commit 1f30d35

File tree

7 files changed

+446
-4
lines changed

7 files changed

+446
-4
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
at
910
atleast_nd
1011
cov
1112
create_diagonal

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,9 @@ ignore = [
235235
"PLR09", # Too many <...>
236236
"PLR2004", # Magic value used in comparison
237237
"ISC001", # Conflicts with formatter
238+
"EM101", # raw-string-in-exception
239+
"EM102", # f-string-in-exception
240+
"PD008", # pandas-use-of-dot-at
238241
]
239242
isort.required-imports = ["from __future__ import annotations"]
240243

src/array_api_extra/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
22

3-
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
3+
from ._funcs import (
4+
at,
5+
atleast_nd,
6+
cov,
7+
create_diagonal,
8+
expand_dims,
9+
kron,
10+
setdiff1d,
11+
sinc,
12+
)
413

514
__version__ = "0.3.3"
615

716
# pylint: disable=duplicate-code
817
__all__ = [
918
"__version__",
19+
"at",
1020
"atleast_nd",
1121
"cov",
1222
"create_diagonal",

src/array_api_extra/_funcs.py

Lines changed: 245 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
11
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
22

3-
import typing
3+
import operator
44
import warnings
5+
from typing import TYPE_CHECKING, Any, Callable, Literal
56

6-
if typing.TYPE_CHECKING:
7+
if TYPE_CHECKING:
78
from ._lib._typing import Array, ModuleType
89

910
from ._lib import _utils
10-
from ._lib._compat import array_namespace
11+
from ._lib._compat import (
12+
array_namespace,
13+
is_array_api_obj,
14+
is_dask_array,
15+
is_jax_array,
16+
is_writeable_array,
17+
)
1118

1219
__all__ = [
20+
"at",
1321
"atleast_nd",
1422
"cov",
1523
"create_diagonal",
@@ -546,3 +554,237 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
546554
x, x, xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device)
547555
)
548556
return xp.sin(y) / y
557+
558+
559+
def _is_fancy_index(idx) -> bool:
560+
if not isinstance(idx, tuple):
561+
idx = (idx,)
562+
return any(isinstance(i, (list, tuple)) or is_array_api_obj(i) for i in idx)
563+
564+
565+
_undef = object()
566+
567+
568+
class at:
569+
"""
570+
Update operations for read-only arrays.
571+
572+
This implements ``jax.numpy.ndarray.at`` for all backends.
573+
574+
Keyword arguments are passed verbatim to backends that support the `ndarray.at`
575+
method; e.g. you may pass ``indices_are_sorted=True`` to JAX; they are quietly
576+
ignored for backends that don't support them.
577+
578+
Additionally, this introduces support for the `copy` keyword for all backends:
579+
580+
None
581+
The array parameter *may* be modified in place if it is possible and beneficial
582+
for performance. You should not reuse it after calling this function.
583+
True
584+
Ensure that the inputs are not modified. This is the default.
585+
False
586+
Raise ValueError if a copy cannot be avoided.
587+
588+
Examples
589+
--------
590+
Given either of these equivalent expressions::
591+
592+
x = at(x)[1].add(2, copy=None)
593+
x = at(x, 1).add(2, copy=None)
594+
595+
If x is a JAX array, they are the same as::
596+
597+
x = x.at[1].add(2)
598+
599+
If x is a read-only numpy array, they are the same as::
600+
601+
x = x.copy()
602+
x[1] += 2
603+
604+
Otherwise, they are the same as::
605+
606+
x[1] += 2
607+
608+
Warning
609+
-------
610+
When you use copy=None, you should always immediately overwrite
611+
the parameter array::
612+
613+
x = at(x, 0).set(2, copy=None)
614+
615+
The anti-pattern below must be avoided, as it will result in different behaviour
616+
on read-only versus writeable arrays::
617+
618+
x = xp.asarray([0, 0, 0])
619+
y = at(x, 0).set(2, copy=None)
620+
z = at(x, 1).set(3, copy=None)
621+
622+
In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]``
623+
when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable!
624+
625+
Warning
626+
-------
627+
The behaviour of update methods when the index is an array of integers which
628+
contains multiple occurrences of the same index is undefined;
629+
e.g. ``at(x, [0, 0]).set(2)``
630+
631+
Note
632+
----
633+
`sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet.
634+
635+
See Also
636+
--------
637+
`jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_
638+
"""
639+
640+
x: Array
641+
idx: Any
642+
__slots__ = ("x", "idx")
643+
644+
def __init__(self, x: Array, idx: Any = _undef, /):
645+
self.x = x
646+
self.idx = idx
647+
648+
def __getitem__(self, idx):
649+
"""
650+
Allow for the alternate syntax ``at(x)[start:stop:step]``,
651+
which looks prettier than ``at(x, slice(start, stop, step))``
652+
and feels more intuitive coming from the JAX documentation.
653+
"""
654+
if self.idx is not _undef:
655+
raise ValueError("Index has already been set")
656+
self.idx = idx
657+
return self
658+
659+
def _common(
660+
self,
661+
at_op: str,
662+
y=_undef,
663+
copy: bool | None | Literal["_force_false"] = True,
664+
**kwargs,
665+
):
666+
"""Perform common prepocessing.
667+
668+
Returns
669+
-------
670+
If the operation can be resolved by at[], (return value, None)
671+
Otherwise, (None, preprocessed x)
672+
"""
673+
if self.idx is _undef:
674+
raise TypeError(
675+
"Index has not been set.\n"
676+
"Usage: either\n"
677+
" at(x, idx).set(value)\n"
678+
"or\n"
679+
" at(x)[idx].set(value)\n"
680+
"(same for all other methods)."
681+
)
682+
683+
x = self.x
684+
685+
if copy is False:
686+
if not is_writeable_array(x) or is_dask_array(x):
687+
raise ValueError("Cannot modify parameter in place")
688+
elif copy is None:
689+
copy = not is_writeable_array(x)
690+
elif copy == "_force_false":
691+
copy = False
692+
elif copy is not True:
693+
raise ValueError(f"Invalid value for copy: {copy!r}")
694+
695+
if is_jax_array(x):
696+
# Use JAX's at[]
697+
at_ = x.at[self.idx]
698+
args = (y,) if y is not _undef else ()
699+
return getattr(at_, at_op)(*args, **kwargs), None
700+
701+
# Emulate at[] behaviour for non-JAX arrays
702+
if copy:
703+
# FIXME We blindly expect the output of x.copy() to be always writeable.
704+
# This holds true for read-only numpy arrays, but not necessarily for
705+
# other backends.
706+
xp = array_namespace(x)
707+
x = xp.asarray(x, copy=True)
708+
709+
return None, x
710+
711+
def get(self, **kwargs):
712+
"""
713+
Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
714+
that the output is either a copy or a view; it also allows passing
715+
keyword arguments to the backend.
716+
"""
717+
# __getitem__ with a fancy index always returns a copy.
718+
# Avoid an unnecessary double copy.
719+
# If copy is forced to False, raise.
720+
if _is_fancy_index(self.idx):
721+
if kwargs.get("copy", True) is False:
722+
raise TypeError(
723+
"Indexing a numpy array with a fancy index always "
724+
"results in a copy"
725+
)
726+
# Skip copy inside _common, even if array is not writeable
727+
kwargs["copy"] = "_force_false"
728+
729+
res, x = self._common("get", **kwargs)
730+
if res is not None:
731+
return res
732+
return x[self.idx]
733+
734+
def set(self, y, /, **kwargs):
735+
"""Apply ``x[idx] = y`` and return the update array"""
736+
res, x = self._common("set", y, **kwargs)
737+
if res is not None:
738+
return res
739+
x[self.idx] = y
740+
return x
741+
742+
def _iop(
743+
self, at_op: str, elwise_op: Callable[[Array, Array], Array], y: Array, **kwargs
744+
):
745+
"""x[idx] += y or equivalent in-place operation on a subset of x
746+
747+
which is the same as saying
748+
x[idx] = x[idx] + y
749+
Note that this is not the same as
750+
operator.iadd(x[idx], y)
751+
Consider for example when x is a numpy array and idx is a fancy index, which
752+
triggers a deep copy on __getitem__.
753+
"""
754+
res, x = self._common(at_op, y, **kwargs)
755+
if res is not None:
756+
return res
757+
x[self.idx] = elwise_op(x[self.idx], y)
758+
return x
759+
760+
def add(self, y, /, **kwargs):
761+
"""Apply ``x[idx] += y`` and return the updated array"""
762+
return self._iop("add", operator.add, y, **kwargs)
763+
764+
def subtract(self, y, /, **kwargs):
765+
"""Apply ``x[idx] -= y`` and return the updated array"""
766+
return self._iop("subtract", operator.sub, y, **kwargs)
767+
768+
def multiply(self, y, /, **kwargs):
769+
"""Apply ``x[idx] *= y`` and return the updated array"""
770+
return self._iop("multiply", operator.mul, y, **kwargs)
771+
772+
def divide(self, y, /, **kwargs):
773+
"""Apply ``x[idx] /= y`` and return the updated array"""
774+
return self._iop("divide", operator.truediv, y, **kwargs)
775+
776+
def power(self, y, /, **kwargs):
777+
"""Apply ``x[idx] **= y`` and return the updated array"""
778+
return self._iop("power", operator.pow, y, **kwargs)
779+
780+
def min(self, y, /, **kwargs):
781+
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
782+
xp = array_namespace(self.x)
783+
y = xp.asarray(y)
784+
return self._iop("min", xp.minimum, y, **kwargs)
785+
786+
def max(self, y, /, **kwargs):
787+
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
788+
xp = array_namespace(self.x)
789+
y = xp.asarray(y)
790+
return self._iop("max", xp.maximum, y, **kwargs)

src/array_api_extra/_lib/_compat.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,25 @@
66
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
77
array_namespace, # pyright: ignore[reportUnknownVariableType]
88
device, # pyright: ignore[reportUnknownVariableType]
9+
is_array_api_obj, # pyright: ignore[reportUnknownVariableType]
10+
is_dask_array, # pyright: ignore[reportUnknownVariableType]
11+
is_jax_array, # pyright: ignore[reportUnknownVariableType]
12+
is_writeable_array, # pyright: ignore[reportUnknownVariableType]
913
)
1014
except ImportError:
1115
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
1216
array_namespace, # pyright: ignore[reportUnknownVariableType]
1317
device,
18+
is_dask_array, # pyright: ignore[reportUnknownVariableType]
19+
is_jax_array, # pyright: ignore[reportUnknownVariableType]
20+
is_writeable_array, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
1421
)
1522

1623
__all__ = [
1724
"array_namespace",
1825
"device",
26+
"is_array_api_obj",
27+
"is_dask_array",
28+
"is_jax_array",
29+
"is_writeable_array",
1930
]

src/array_api_extra/_lib/_compat.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@ def array_namespace(
1111
use_compat: bool | None = None,
1212
) -> ArrayModule: ...
1313
def device(x: Array, /) -> Device: ...
14+
def is_array_api_obj(x: object, /) -> bool: ...
15+
def is_dask_array(x: object, /) -> bool: ...
16+
def is_jax_array(x: object, /) -> bool: ...
17+
def is_writeable_array(x: object, /) -> bool: ...

0 commit comments

Comments
 (0)