1313 array_namespace ,
1414 is_array_api_obj ,
1515 is_dask_array ,
16+ is_jax_array ,
17+ is_pydata_sparse_array ,
1618 is_writeable_array ,
1719)
1820
1921if typing .TYPE_CHECKING :
20- from ._lib ._typing import Array , Index , ModuleType , Untyped
22+ from ._lib ._typing import Array , Index , ModuleType
2123
2224__all__ = [
2325 "at" ,
@@ -593,11 +595,6 @@ class at: # pylint: disable=invalid-name
593595 xp : array_namespace, optional
594596 The standard-compatible namespace for `x`. Default: infer
595597
596- **kwargs:
597- If the backend supports an `at` method, any additional keyword
598- arguments are passed to it verbatim; e.g. this allows passing
599- ``indices_are_sorted=True`` to JAX.
600-
601598 Returns
602599 -------
603600 Updated input array.
@@ -674,23 +671,7 @@ def __getitem__(self, idx: Index, /) -> at:
674671 self .idx = idx
675672 return self
676673
677- def _common (
678- self ,
679- at_op : str ,
680- y : Array = _undef ,
681- / ,
682- copy : bool | None = True ,
683- xp : ModuleType | None = None ,
684- _is_update : bool = True ,
685- ** kwargs : Untyped ,
686- ) -> tuple [Array , None ] | tuple [None , Array ]:
687- """Perform common prepocessing.
688-
689- Returns
690- -------
691- If the operation can be resolved by at[], (return value, None)
692- Otherwise, (None, preprocessed x)
693- """
674+ def _check_args (self , / , copy : bool | None ) -> None :
694675 if self .idx is _undef :
695676 msg = (
696677 "Index has not been set.\n "
@@ -702,64 +683,23 @@ def _common(
702683 )
703684 raise TypeError (msg )
704685
705- x = self .x
706-
707686 if copy not in (True , False , None ):
708687 msg = f"copy must be True, False, or None; got { copy !r} " # pyright: ignore[reportUnreachable]
709688 raise ValueError (msg )
710689
711- if copy is None :
712- writeable = is_writeable_array (x )
713- copy = _is_update and not writeable
714- elif copy :
715- writeable = None
716- elif _is_update :
717- writeable = is_writeable_array (x )
718- if not writeable :
719- msg = "Cannot modify parameter in place"
720- raise ValueError (msg )
721- else :
722- writeable = None
723-
724- if copy :
725- try :
726- at_ = x .at
727- except AttributeError :
728- # Emulate at[] behaviour for non-JAX arrays
729- # with a copy followed by an update
730- if xp is None :
731- xp = array_namespace (x )
732- x = xp .asarray (x , copy = True )
733- if writeable is False :
734- # A copy of a read-only numpy array is writeable
735- # Note: this assumes that a copy of a writeable array is writeable
736- writeable = None
737- else :
738- # Use JAX's at[] or other library that with the same duck-type API
739- args = (y ,) if y is not _undef else ()
740- return getattr (at_ [self .idx ], at_op )(* args , ** kwargs ), None
741-
742- if _is_update :
743- if writeable is None :
744- writeable = is_writeable_array (x )
745- if not writeable :
746- # sparse crashes here
747- msg = f"Array { x } has no `at` method and is read-only"
748- raise ValueError (msg )
749-
750- return None , x
751-
752690 def get (
753691 self ,
754692 / ,
755693 copy : bool | None = True ,
756694 xp : ModuleType | None = None ,
757- ** kwargs : Untyped ,
758- ) -> Untyped :
759- """Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
760- that the output is either a copy or a view; it also allows passing
695+ ) -> Array :
696+ """Return ``xp.asarray(x[idx])``. In addition to plain ``__getitem__``, this allows
697+ ensuring that the output is either a copy or a view; it also allows passing
761698 keyword arguments to the backend.
762699 """
700+ self ._check_args (copy = copy )
701+ x = self .x
702+
763703 if copy is False :
764704 if is_array_api_obj (self .idx ):
765705 # Boolean index. Note that the array API spec
@@ -782,26 +722,81 @@ def get(
782722 msg = "get() with a scalar index typically returns a copy"
783723 raise ValueError (msg )
784724
785- if is_dask_array (self .x ):
786- msg = "get() on Dask arrays always returns a copy"
725+ # Note: this is not the same list of backends as is_writeable_array()
726+ if is_dask_array (x ) or is_jax_array (x ) or is_pydata_sparse_array (x ):
727+ msg = f"get() on { array_namespace (x )} arrays always returns a copy"
787728 raise ValueError (msg )
788729
789- res , x = self ._common ("get" , copy = copy , xp = xp , _is_update = False , ** kwargs )
790- if res is not None :
791- return res
792- assert x is not None
793- return x [self .idx ]
730+ if is_jax_array (x ):
731+ # Use JAX's at[] or other library that with the same duck-type API
732+ return x .at [self .idx ].get ()
733+
734+ if xp is None :
735+ xp = array_namespace (x )
736+ # Note: when self.idx is a boolean mask, numpy always returns a deep copy.
737+ # However, some backends may legitimately return a view when the mask can
738+ # be downgraded to a slice, e.g. a[[True, True, False]] -> a[:2].
739+ # Err on the side of caution and perform a double-copy in numpy.
740+ return xp .asarray (x [self .idx ], copy = copy )
741+
742+ def _update_common (
743+ self ,
744+ at_op : str ,
745+ y : Array = _undef ,
746+ / ,
747+ copy : bool | None = True ,
748+ xp : ModuleType | None = None ,
749+ ) -> tuple [Array , None ] | tuple [None , Array ]:
750+ """Perform common prepocessing to all update operations.
751+
752+ Returns
753+ -------
754+ If the operation can be resolved by at[], (return value, None)
755+ Otherwise, (None, preprocessed x)
756+ """
757+ x = self .x
758+ if copy is None :
759+ writeable = is_writeable_array (x )
760+ copy = not writeable
761+ elif copy :
762+ writeable = None
763+ else :
764+ writeable = is_writeable_array (x )
765+
766+ if copy :
767+ if is_jax_array (x ):
768+ # Use JAX's at[] or other library that with the same duck-type API
769+ func = getattr (x .at [self .idx ], at_op )
770+ return func (y ) if y is not _undef else func (), None
771+ # Emulate at[] behaviour for non-JAX arrays
772+ # with a copy followed by an update
773+ if xp is None :
774+ xp = array_namespace (x )
775+ x = xp .asarray (x , copy = True )
776+ if writeable is False :
777+ # A copy of a read-only numpy array is writeable
778+ # Note: this assumes that a copy of a writeable array is writeable
779+ writeable = None
780+
781+ if writeable is None :
782+ writeable = is_writeable_array (x )
783+ if not writeable :
784+ # sparse crashes here
785+ msg = f"Array { x } has no `at` method and is read-only"
786+ raise ValueError (msg )
787+
788+ return None , x
794789
795790 def set (
796791 self ,
797792 y : Array ,
798793 / ,
799794 copy : bool | None = True ,
800795 xp : ModuleType | None = None ,
801- ** kwargs : Untyped ,
802796 ) -> Array :
803797 """Apply ``x[idx] = y`` and return the update array"""
804- res , x = self ._common ("set" , y , copy = copy , xp = xp , ** kwargs )
798+ self ._check_args (copy = copy )
799+ res , x = self ._update_common ("set" , y , copy = copy , xp = xp )
805800 if res is not None :
806801 return res
807802 assert x is not None
@@ -818,7 +813,6 @@ def _iop(
818813 / ,
819814 copy : bool | None = True ,
820815 xp : ModuleType | None = None ,
821- ** kwargs : Untyped ,
822816 ) -> Array :
823817 """x[idx] += y or equivalent in-place operation on a subset of x
824818
@@ -829,7 +823,8 @@ def _iop(
829823 Consider for example when x is a numpy array and idx is a fancy index, which
830824 triggers a deep copy on __getitem__.
831825 """
832- res , x = self ._common (at_op , y , copy = copy , xp = xp , ** kwargs )
826+ self ._check_args (copy = copy )
827+ res , x = self ._update_common (at_op , y , copy = copy , xp = xp )
833828 if res is not None :
834829 return res
835830 assert x is not None
@@ -842,79 +837,72 @@ def add(
842837 / ,
843838 copy : bool | None = True ,
844839 xp : ModuleType | None = None ,
845- ** kwargs : Untyped ,
846840 ) -> Array :
847841 """Apply ``x[idx] += y`` and return the updated array"""
848- return self ._iop ("add" , operator .add , y , copy = copy , xp = xp , ** kwargs )
842+ return self ._iop ("add" , operator .add , y , copy = copy , xp = xp )
849843
850844 def subtract (
851845 self ,
852846 y : Array ,
853847 / ,
854848 copy : bool | None = True ,
855849 xp : ModuleType | None = None ,
856- ** kwargs : Untyped ,
857850 ) -> Array :
858851 """Apply ``x[idx] -= y`` and return the updated array"""
859- return self ._iop ("subtract" , operator .sub , y , copy = copy , xp = xp , ** kwargs )
852+ return self ._iop ("subtract" , operator .sub , y , copy = copy , xp = xp )
860853
861854 def multiply (
862855 self ,
863856 y : Array ,
864857 / ,
865858 copy : bool | None = True ,
866859 xp : ModuleType | None = None ,
867- ** kwargs : Untyped ,
868860 ) -> Array :
869861 """Apply ``x[idx] *= y`` and return the updated array"""
870- return self ._iop ("multiply" , operator .mul , y , copy = copy , xp = xp , ** kwargs )
862+ return self ._iop ("multiply" , operator .mul , y , copy = copy , xp = xp )
871863
872864 def divide (
873865 self ,
874866 y : Array ,
875867 / ,
876868 copy : bool | None = True ,
877869 xp : ModuleType | None = None ,
878- ** kwargs : Untyped ,
879870 ) -> Array :
880871 """Apply ``x[idx] /= y`` and return the updated array"""
881- return self ._iop ("divide" , operator .truediv , y , copy = copy , xp = xp , ** kwargs )
872+ return self ._iop ("divide" , operator .truediv , y , copy = copy , xp = xp )
882873
883874 def power (
884875 self ,
885876 y : Array ,
886877 / ,
887878 copy : bool | None = True ,
888879 xp : ModuleType | None = None ,
889- ** kwargs : Untyped ,
890880 ) -> Array :
891881 """Apply ``x[idx] **= y`` and return the updated array"""
892- return self ._iop ("power" , operator .pow , y , copy = copy , xp = xp , ** kwargs )
882+ return self ._iop ("power" , operator .pow , y , copy = copy , xp = xp )
893883
894884 def min (
895885 self ,
896886 y : Array ,
897887 / ,
898888 copy : bool | None = True ,
899889 xp : ModuleType | None = None ,
900- ** kwargs : Untyped ,
901890 ) -> Array :
902891 """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
903892 if xp is None :
904893 xp = array_namespace (self .x )
905894 y = xp .asarray (y )
906- return self ._iop ("min" , xp .minimum , y , copy = copy , xp = xp , ** kwargs )
895+ return self ._iop ("min" , xp .minimum , y , copy = copy , xp = xp )
907896
908897 def max (
909898 self ,
910899 y : Array ,
911900 / ,
912901 copy : bool | None = True ,
913902 xp : ModuleType | None = None ,
914- ** kwargs : Untyped ,
915903 ) -> Array :
916904 """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
917905 if xp is None :
918906 xp = array_namespace (self .x )
919907 y = xp .asarray (y )
920- return self ._iop ("max" , xp .maximum , y , copy = copy , xp = xp , ** kwargs )
908+ return self ._iop ("max" , xp .maximum , y , copy = copy , xp = xp )
0 commit comments