1111from ._lib import _utils
1212from ._lib ._compat import (
1313 array_namespace ,
14+ is_array_api_obj ,
15+ is_dask_array ,
1416 is_jax_array ,
17+ is_pydata_sparse_array ,
1518 is_writeable_array ,
1619)
1720
@@ -686,6 +689,73 @@ def __getitem__(self, idx: Index, /) -> at:
686689 self ._idx = idx
687690 return self
688691
692+ def _check_args (self , / , copy : bool | None ) -> None :
693+ if self ._idx is _undef :
694+ msg = (
695+ "Index has not been set.\n "
696+ "Usage: either\n "
697+ " at(x, idx).set(value)\n "
698+ "or\n "
699+ " at(x)[idx].set(value)\n "
700+ "(same for all other methods)."
701+ )
702+ raise TypeError (msg )
703+
704+ if copy not in (True , False , None ):
705+ msg = f"copy must be True, False, or None; got { copy !r} " # pyright: ignore[reportUnreachable]
706+ raise ValueError (msg )
707+
708+ def get (
709+ self ,
710+ / ,
711+ copy : bool | None = True ,
712+ xp : ModuleType | None = None ,
713+ ) -> Array :
714+ """Return ``xp.asarray(x[idx])``. In addition to plain ``__getitem__``,
715+ this allows ensuring that the output is either a copy or a view
716+ """
717+ self ._check_args (copy = copy )
718+ x = self ._x
719+
720+ if copy is False :
721+ if is_array_api_obj (self ._idx ):
722+ # Boolean index. Note that the array API spec
723+ # https://data-apis.org/array-api/latest/API_specification/indexing.html
724+ # does not allow for list, tuple, and tuples of slices plus one or more
725+ # one-dimensional array indices, although many backends support them.
726+ # So this check will encounter a lot of false negatives in real life,
727+ # which can be caught by testing the user code vs. array-api-strict.
728+ msg = "get() with an array index always returns a copy"
729+ raise ValueError (msg )
730+
731+ # Prevent scalar indices together with copy=False.
732+ # Even if some backends may return a scalar view of the original, we chose to be
733+ # strict here beceause some other backends, such as numpy, definitely don't.
734+ tup_idx = self ._idx if isinstance (self ._idx , tuple ) else (self ._idx ,)
735+ if any (
736+ i is not None and i is not Ellipsis and not isinstance (i , slice )
737+ for i in tup_idx
738+ ):
739+ msg = "get() with a scalar index typically returns a copy"
740+ raise ValueError (msg )
741+
742+ # Note: this is not the same list of backends as is_writeable_array()
743+ if is_dask_array (x ) or is_jax_array (x ) or is_pydata_sparse_array (x ):
744+ msg = f"get() on { array_namespace (x )} arrays always returns a copy"
745+ raise ValueError (msg )
746+
747+ if is_jax_array (x ):
748+ # Use JAX's at[] or other library that with the same duck-type API
749+ return x .at [self ._idx ].get ()
750+
751+ if xp is None :
752+ xp = array_namespace (x )
753+ # Note: when idx is a boolean mask, numpy always returns a deep copy.
754+ # However, some backends may legitimately return a view when the mask can
755+ # be downgraded to a slice, e.g. a[[True, True, False]] -> a[:2].
756+ # Err on the side of caution and perform a double-copy in numpy.
757+ return xp .asarray (x [self ._idx ], copy = copy )
758+
689759 def _update_common (
690760 self ,
691761 at_op : str ,
@@ -701,23 +771,7 @@ def _update_common(
701771 If the operation can be resolved by at[], (return value, None)
702772 Otherwise, (None, preprocessed x)
703773 """
704- x , idx = self ._x , self ._idx
705-
706- if idx is _undef :
707- msg = (
708- "Index has not been set.\n "
709- "Usage: either\n "
710- " at(x, idx).set(value)\n "
711- "or\n "
712- " at(x)[idx].set(value)\n "
713- "(same for all other methods)."
714- )
715- raise TypeError (msg )
716-
717- if copy not in (True , False , None ):
718- msg = f"copy must be True, False, or None; got { copy !r} " # pyright: ignore[reportUnreachable]
719- raise ValueError (msg )
720-
774+ x = self ._x
721775 if copy is None :
722776 writeable = is_writeable_array (x )
723777 copy = not writeable
@@ -758,6 +812,7 @@ def set(
758812 xp : ModuleType | None = None ,
759813 ) -> Array :
760814 """Apply ``x[idx] = y`` and return the update array"""
815+ self ._check_args (copy = copy )
761816 res , x = self ._update_common ("set" , y , copy = copy , xp = xp )
762817 if res is not None :
763818 return res
@@ -785,6 +840,7 @@ def _iop(
785840 Consider for example when x is a numpy array and idx is a fancy index, which
786841 triggers a deep copy on __getitem__.
787842 """
843+ self ._check_args (copy = copy )
788844 res , x = self ._update_common (at_op , y , copy = copy , xp = xp )
789845 if res is not None :
790846 return res
0 commit comments