|
1 | 1 | from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990 |
2 | 2 |
|
3 | | -import typing |
| 3 | +import operator |
4 | 4 | import warnings |
| 5 | +from typing import TYPE_CHECKING, Any, Callable, Literal |
5 | 6 |
|
6 | | -if typing.TYPE_CHECKING: |
| 7 | +if TYPE_CHECKING: |
7 | 8 | from ._lib._typing import Array, ModuleType |
8 | 9 |
|
9 | 10 | from ._lib import _utils |
10 | | -from ._lib._compat import array_namespace |
| 11 | +from ._lib._compat import ( |
| 12 | + array_namespace, |
| 13 | + is_array_api_obj, |
| 14 | + is_dask_array, |
| 15 | + is_jax_array, |
| 16 | + is_writeable_array, |
| 17 | +) |
11 | 18 |
|
12 | 19 | __all__ = [ |
| 20 | + "at", |
13 | 21 | "atleast_nd", |
14 | 22 | "cov", |
15 | 23 | "create_diagonal", |
@@ -546,3 +554,237 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: |
546 | 554 | x, x, xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device) |
547 | 555 | ) |
548 | 556 | return xp.sin(y) / y |
| 557 | + |
| 558 | + |
| 559 | +def _is_fancy_index(idx) -> bool: |
| 560 | + if not isinstance(idx, tuple): |
| 561 | + idx = (idx,) |
| 562 | + return any(isinstance(i, (list, tuple)) or is_array_api_obj(i) for i in idx) |
| 563 | + |
| 564 | + |
| 565 | +_undef = object() |
| 566 | + |
| 567 | + |
| 568 | +class at: |
| 569 | + """ |
| 570 | + Update operations for read-only arrays. |
| 571 | +
|
| 572 | + This implements ``jax.numpy.ndarray.at`` for all backends. |
| 573 | +
|
| 574 | + Keyword arguments are passed verbatim to backends that support the `ndarray.at` |
| 575 | + method; e.g. you may pass ``indices_are_sorted=True`` to JAX; they are quietly |
| 576 | + ignored for backends that don't support them. |
| 577 | +
|
| 578 | + Additionally, this introduces support for the `copy` keyword for all backends: |
| 579 | +
|
| 580 | + None |
| 581 | + The array parameter *may* be modified in place if it is possible and beneficial |
| 582 | + for performance. You should not reuse it after calling this function. |
| 583 | + True |
| 584 | + Ensure that the inputs are not modified. This is the default. |
| 585 | + False |
| 586 | + Raise ValueError if a copy cannot be avoided. |
| 587 | +
|
| 588 | + Examples |
| 589 | + -------- |
| 590 | + Given either of these equivalent expressions:: |
| 591 | +
|
| 592 | + x = at(x)[1].add(2, copy=None) |
| 593 | + x = at(x, 1).add(2, copy=None) |
| 594 | +
|
| 595 | + If x is a JAX array, they are the same as:: |
| 596 | +
|
| 597 | + x = x.at[1].add(2) |
| 598 | +
|
| 599 | + If x is a read-only numpy array, they are the same as:: |
| 600 | +
|
| 601 | + x = x.copy() |
| 602 | + x[1] += 2 |
| 603 | +
|
| 604 | + Otherwise, they are the same as:: |
| 605 | +
|
| 606 | + x[1] += 2 |
| 607 | +
|
| 608 | + Warning |
| 609 | + ------- |
| 610 | + When you use copy=None, you should always immediately overwrite |
| 611 | + the parameter array:: |
| 612 | +
|
| 613 | + x = at(x, 0).set(2, copy=None) |
| 614 | +
|
| 615 | + The anti-pattern below must be avoided, as it will result in different behaviour |
| 616 | + on read-only versus writeable arrays:: |
| 617 | +
|
| 618 | + x = xp.asarray([0, 0, 0]) |
| 619 | + y = at(x, 0).set(2, copy=None) |
| 620 | + z = at(x, 1).set(3, copy=None) |
| 621 | +
|
| 622 | + In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]`` |
| 623 | + when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable! |
| 624 | +
|
| 625 | + Warning |
| 626 | + ------- |
| 627 | + The behaviour of update methods when the index is an array of integers which |
| 628 | + contains multiple occurrences of the same index is undefined; |
| 629 | + e.g. ``at(x, [0, 0]).set(2)`` |
| 630 | +
|
| 631 | + Note |
| 632 | + ---- |
| 633 | + `sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet. |
| 634 | +
|
| 635 | + See Also |
| 636 | + -------- |
| 637 | + `jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_ |
| 638 | + """ |
| 639 | + |
| 640 | + x: Array |
| 641 | + idx: Any |
| 642 | + __slots__ = ("x", "idx") |
| 643 | + |
| 644 | + def __init__(self, x: Array, idx: Any = _undef, /): |
| 645 | + self.x = x |
| 646 | + self.idx = idx |
| 647 | + |
| 648 | + def __getitem__(self, idx): |
| 649 | + """ |
| 650 | + Allow for the alternate syntax ``at(x)[start:stop:step]``, |
| 651 | + which looks prettier than ``at(x, slice(start, stop, step))`` |
| 652 | + and feels more intuitive coming from the JAX documentation. |
| 653 | + """ |
| 654 | + if self.idx is not _undef: |
| 655 | + raise ValueError("Index has already been set") |
| 656 | + self.idx = idx |
| 657 | + return self |
| 658 | + |
| 659 | + def _common( |
| 660 | + self, |
| 661 | + at_op: str, |
| 662 | + y=_undef, |
| 663 | + copy: bool | None | Literal["_force_false"] = True, |
| 664 | + **kwargs, |
| 665 | + ): |
| 666 | + """Perform common prepocessing. |
| 667 | +
|
| 668 | + Returns |
| 669 | + ------- |
| 670 | + If the operation can be resolved by at[], (return value, None) |
| 671 | + Otherwise, (None, preprocessed x) |
| 672 | + """ |
| 673 | + if self.idx is _undef: |
| 674 | + raise TypeError( |
| 675 | + "Index has not been set.\n" |
| 676 | + "Usage: either\n" |
| 677 | + " at(x, idx).set(value)\n" |
| 678 | + "or\n" |
| 679 | + " at(x)[idx].set(value)\n" |
| 680 | + "(same for all other methods)." |
| 681 | + ) |
| 682 | + |
| 683 | + x = self.x |
| 684 | + |
| 685 | + if copy is False: |
| 686 | + if not is_writeable_array(x) or is_dask_array(x): |
| 687 | + raise ValueError("Cannot modify parameter in place") |
| 688 | + elif copy is None: |
| 689 | + copy = not is_writeable_array(x) |
| 690 | + elif copy == "_force_false": |
| 691 | + copy = False |
| 692 | + elif copy is not True: |
| 693 | + raise ValueError(f"Invalid value for copy: {copy!r}") |
| 694 | + |
| 695 | + if is_jax_array(x): |
| 696 | + # Use JAX's at[] |
| 697 | + at_ = x.at[self.idx] |
| 698 | + args = (y,) if y is not _undef else () |
| 699 | + return getattr(at_, at_op)(*args, **kwargs), None |
| 700 | + |
| 701 | + # Emulate at[] behaviour for non-JAX arrays |
| 702 | + if copy: |
| 703 | + # FIXME We blindly expect the output of x.copy() to be always writeable. |
| 704 | + # This holds true for read-only numpy arrays, but not necessarily for |
| 705 | + # other backends. |
| 706 | + xp = array_namespace(x) |
| 707 | + x = xp.asarray(x, copy=True) |
| 708 | + |
| 709 | + return None, x |
| 710 | + |
| 711 | + def get(self, **kwargs): |
| 712 | + """ |
| 713 | + Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring |
| 714 | + that the output is either a copy or a view; it also allows passing |
| 715 | + keyword arguments to the backend. |
| 716 | + """ |
| 717 | + # __getitem__ with a fancy index always returns a copy. |
| 718 | + # Avoid an unnecessary double copy. |
| 719 | + # If copy is forced to False, raise. |
| 720 | + if _is_fancy_index(self.idx): |
| 721 | + if kwargs.get("copy", True) is False: |
| 722 | + raise TypeError( |
| 723 | + "Indexing a numpy array with a fancy index always " |
| 724 | + "results in a copy" |
| 725 | + ) |
| 726 | + # Skip copy inside _common, even if array is not writeable |
| 727 | + kwargs["copy"] = "_force_false" |
| 728 | + |
| 729 | + res, x = self._common("get", **kwargs) |
| 730 | + if res is not None: |
| 731 | + return res |
| 732 | + return x[self.idx] |
| 733 | + |
| 734 | + def set(self, y, /, **kwargs): |
| 735 | + """Apply ``x[idx] = y`` and return the update array""" |
| 736 | + res, x = self._common("set", y, **kwargs) |
| 737 | + if res is not None: |
| 738 | + return res |
| 739 | + x[self.idx] = y |
| 740 | + return x |
| 741 | + |
| 742 | + def _iop( |
| 743 | + self, at_op: str, elwise_op: Callable[[Array, Array], Array], y: Array, **kwargs |
| 744 | + ): |
| 745 | + """x[idx] += y or equivalent in-place operation on a subset of x |
| 746 | +
|
| 747 | + which is the same as saying |
| 748 | + x[idx] = x[idx] + y |
| 749 | + Note that this is not the same as |
| 750 | + operator.iadd(x[idx], y) |
| 751 | + Consider for example when x is a numpy array and idx is a fancy index, which |
| 752 | + triggers a deep copy on __getitem__. |
| 753 | + """ |
| 754 | + res, x = self._common(at_op, y, **kwargs) |
| 755 | + if res is not None: |
| 756 | + return res |
| 757 | + x[self.idx] = elwise_op(x[self.idx], y) |
| 758 | + return x |
| 759 | + |
| 760 | + def add(self, y, /, **kwargs): |
| 761 | + """Apply ``x[idx] += y`` and return the updated array""" |
| 762 | + return self._iop("add", operator.add, y, **kwargs) |
| 763 | + |
| 764 | + def subtract(self, y, /, **kwargs): |
| 765 | + """Apply ``x[idx] -= y`` and return the updated array""" |
| 766 | + return self._iop("subtract", operator.sub, y, **kwargs) |
| 767 | + |
| 768 | + def multiply(self, y, /, **kwargs): |
| 769 | + """Apply ``x[idx] *= y`` and return the updated array""" |
| 770 | + return self._iop("multiply", operator.mul, y, **kwargs) |
| 771 | + |
| 772 | + def divide(self, y, /, **kwargs): |
| 773 | + """Apply ``x[idx] /= y`` and return the updated array""" |
| 774 | + return self._iop("divide", operator.truediv, y, **kwargs) |
| 775 | + |
| 776 | + def power(self, y, /, **kwargs): |
| 777 | + """Apply ``x[idx] **= y`` and return the updated array""" |
| 778 | + return self._iop("power", operator.pow, y, **kwargs) |
| 779 | + |
| 780 | + def min(self, y, /, **kwargs): |
| 781 | + """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array""" |
| 782 | + xp = array_namespace(self.x) |
| 783 | + y = xp.asarray(y) |
| 784 | + return self._iop("min", xp.minimum, y, **kwargs) |
| 785 | + |
| 786 | + def max(self, y, /, **kwargs): |
| 787 | + """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array""" |
| 788 | + xp = array_namespace(self.x) |
| 789 | + y = xp.asarray(y) |
| 790 | + return self._iop("max", xp.maximum, y, **kwargs) |
0 commit comments