1111from ._lib import _utils
1212from ._lib ._compat import (
1313 array_namespace ,
14- is_array_api_obj ,
15- is_dask_array ,
1614 is_jax_array ,
17- is_pydata_sparse_array ,
1815 is_writeable_array ,
1916)
2017
@@ -689,73 +686,6 @@ def __getitem__(self, idx: Index, /) -> at:
689686 self ._idx = idx
690687 return self
691688
692- def _check_args (self , / , copy : bool | None ) -> None :
693- if self ._idx is _undef :
694- msg = (
695- "Index has not been set.\n "
696- "Usage: either\n "
697- " at(x, idx).set(value)\n "
698- "or\n "
699- " at(x)[idx].set(value)\n "
700- "(same for all other methods)."
701- )
702- raise TypeError (msg )
703-
704- if copy not in (True , False , None ):
705- msg = f"copy must be True, False, or None; got { copy !r} " # pyright: ignore[reportUnreachable]
706- raise ValueError (msg )
707-
708- def get (
709- self ,
710- / ,
711- copy : bool | None = True ,
712- xp : ModuleType | None = None ,
713- ) -> Array :
714- """Return ``xp.asarray(x[idx])``. In addition to plain ``__getitem__``,
715- this allows ensuring that the output is either a copy or a view
716- """
717- self ._check_args (copy = copy )
718- x = self ._x
719-
720- if copy is False :
721- if is_array_api_obj (self ._idx ):
722- # Boolean index. Note that the array API spec
723- # https://data-apis.org/array-api/latest/API_specification/indexing.html
724- # does not allow for list, tuple, and tuples of slices plus one or more
725- # one-dimensional array indices, although many backends support them.
726- # So this check will encounter a lot of false negatives in real life,
727- # which can be caught by testing the user code vs. array-api-strict.
728- msg = "get() with an array index always returns a copy"
729- raise ValueError (msg )
730-
731- # Prevent scalar indices together with copy=False.
732- # Even if some backends may return a scalar view of the original, we chose to be
733- # strict here beceause some other backends, such as numpy, definitely don't.
734- tup_idx = self ._idx if isinstance (self ._idx , tuple ) else (self ._idx ,)
735- if any (
736- i is not None and i is not Ellipsis and not isinstance (i , slice )
737- for i in tup_idx
738- ):
739- msg = "get() with a scalar index typically returns a copy"
740- raise ValueError (msg )
741-
742- # Note: this is not the same list of backends as is_writeable_array()
743- if is_dask_array (x ) or is_jax_array (x ) or is_pydata_sparse_array (x ):
744- msg = f"get() on { array_namespace (x )} arrays always returns a copy"
745- raise ValueError (msg )
746-
747- if is_jax_array (x ):
748- # Use JAX's at[] or other library that with the same duck-type API
749- return x .at [self ._idx ].get ()
750-
751- if xp is None :
752- xp = array_namespace (x )
753- # Note: when idx is a boolean mask, numpy always returns a deep copy.
754- # However, some backends may legitimately return a view when the mask can
755- # be downgraded to a slice, e.g. a[[True, True, False]] -> a[:2].
756- # Err on the side of caution and perform a double-copy in numpy.
757- return xp .asarray (x [self ._idx ], copy = copy )
758-
759689 def _update_common (
760690 self ,
761691 at_op : str ,
@@ -771,7 +701,23 @@ def _update_common(
771701 If the operation can be resolved by at[], (return value, None)
772702 Otherwise, (None, preprocessed x)
773703 """
774- x = self ._x
704+ x , idx = self ._x , self ._idx
705+
706+ if idx is _undef :
707+ msg = (
708+ "Index has not been set.\n "
709+ "Usage: either\n "
710+ " at(x, idx).set(value)\n "
711+ "or\n "
712+ " at(x)[idx].set(value)\n "
713+ "(same for all other methods)."
714+ )
715+ raise TypeError (msg )
716+
717+ if copy not in (True , False , None ):
718+ msg = f"copy must be True, False, or None; got { copy !r} " # pyright: ignore[reportUnreachable]
719+ raise ValueError (msg )
720+
775721 if copy is None :
776722 writeable = is_writeable_array (x )
777723 copy = not writeable
@@ -812,7 +758,6 @@ def set(
812758 xp : ModuleType | None = None ,
813759 ) -> Array :
814760 """Apply ``x[idx] = y`` and return the update array"""
815- self ._check_args (copy = copy )
816761 res , x = self ._update_common ("set" , y , copy = copy , xp = xp )
817762 if res is not None :
818763 return res
@@ -840,7 +785,6 @@ def _iop(
840785 Consider for example when x is a numpy array and idx is a fancy index, which
841786 triggers a deep copy on __getitem__.
842787 """
843- self ._check_args (copy = copy )
844788 res , x = self ._update_common (at_op , y , copy = copy , xp = xp )
845789 if res is not None :
846790 return res
0 commit comments