@@ -683,7 +683,7 @@ def _common(
683683 xp : ModuleType | None = None ,
684684 _is_update : bool = True ,
685685 ** kwargs : Untyped ,
686- ) -> tuple [Untyped , None ] | tuple [None , Array ]:
686+ ) -> tuple [Array , None ] | tuple [None , Array ]:
687687 """Perform common prepocessing.
688688
689689 Returns
@@ -704,16 +704,22 @@ def _common(
704704
705705 x = self .x
706706
707+ if copy not in (True , False , None ):
708+ msg = f"copy must be True, False, or None; got { copy !r} " # pyright: ignore[reportUnreachable]
709+ raise ValueError (msg )
710+
707711 if copy is None :
708712 writeable = is_writeable_array (x )
709713 copy = _is_update and not writeable
710714 elif copy :
711715 writeable = None
712- else :
716+ elif _is_update :
713717 writeable = is_writeable_array (x )
714718 if not writeable :
715719 msg = "Cannot modify parameter in place"
716720 raise ValueError (msg )
721+ else :
722+ writeable = None
717723
718724 if copy :
719725 try :
@@ -723,10 +729,10 @@ def _common(
723729 # with a copy followed by an update
724730 if xp is None :
725731 xp = array_namespace (x )
726- # Create writeable copy of read-only numpy array
727732 x = xp .asarray (x , copy = True )
728733 if writeable is False :
729734 # A copy of a read-only numpy array is writeable
735+ # Note: this assumes that a copy of a writeable array is writeable
730736 writeable = None
731737 else :
732738 # Use JAX's at[] or other library that with the same duck-type API
@@ -743,12 +749,18 @@ def _common(
743749
744750 return None , x
745751
746- def get (self , ** kwargs : Untyped ) -> Untyped :
752+ def get (
753+ self ,
754+ / ,
755+ copy : bool | None = True ,
756+ xp : ModuleType | None = None ,
757+ ** kwargs : Untyped ,
758+ ) -> Untyped :
747759 """Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
748760 that the output is either a copy or a view; it also allows passing
749761 keyword arguments to the backend.
750762 """
751- if kwargs . get ( " copy" ) is False :
763+ if copy is False :
752764 if is_array_api_obj (self .idx ):
753765 # Boolean index. Note that the array API spec
754766 # https://data-apis.org/array-api/latest/API_specification/indexing.html
@@ -758,19 +770,38 @@ def get(self, **kwargs: Untyped) -> Untyped:
758770 # which can be caught by testing the user code vs. array-api-strict.
759771 msg = "get() with an array index always returns a copy"
760772 raise ValueError (msg )
773+
774+ # Prevent scalar indices together with copy=False.
775+ # Even if some backends may return a scalar view of the original, we chose to be
776+ # strict here beceause some other backends, such as numpy, definitely don't.
777+ tup_idx = self .idx if isinstance (self .idx , tuple ) else (self .idx ,)
778+ if any (
779+ i is not None and i is not Ellipsis and not isinstance (i , slice )
780+ for i in tup_idx
781+ ):
782+ msg = "get() with a scalar index typically returns a copy"
783+ raise ValueError (msg )
784+
761785 if is_dask_array (self .x ):
762786 msg = "get() on Dask arrays always returns a copy"
763787 raise ValueError (msg )
764788
765- res , x = self ._common ("get" , _is_update = False , ** kwargs )
789+ res , x = self ._common ("get" , copy = copy , xp = xp , _is_update = False , ** kwargs )
766790 if res is not None :
767791 return res
768792 assert x is not None
769793 return x [self .idx ]
770794
771- def set (self , y : Array , / , ** kwargs : Untyped ) -> Array :
795+ def set (
796+ self ,
797+ y : Array ,
798+ / ,
799+ copy : bool | None = True ,
800+ xp : ModuleType | None = None ,
801+ ** kwargs : Untyped ,
802+ ) -> Array :
772803 """Apply ``x[idx] = y`` and return the update array"""
773- res , x = self ._common ("set" , y , ** kwargs )
804+ res , x = self ._common ("set" , y , copy = copy , xp = xp , ** kwargs )
774805 if res is not None :
775806 return res
776807 assert x is not None
@@ -785,6 +816,8 @@ def _iop(
785816 elwise_op : Callable [[Array , Array ], Array ],
786817 y : Array ,
787818 / ,
819+ copy : bool | None = True ,
820+ xp : ModuleType | None = None ,
788821 ** kwargs : Untyped ,
789822 ) -> Array :
790823 """x[idx] += y or equivalent in-place operation on a subset of x
@@ -796,41 +829,92 @@ def _iop(
796829 Consider for example when x is a numpy array and idx is a fancy index, which
797830 triggers a deep copy on __getitem__.
798831 """
799- res , x = self ._common (at_op , y , ** kwargs )
832+ res , x = self ._common (at_op , y , copy = copy , xp = xp , ** kwargs )
800833 if res is not None :
801834 return res
802835 assert x is not None
803836 x [self .idx ] = elwise_op (x [self .idx ], y )
804837 return x
805838
806- def add (self , y : Array , / , ** kwargs : Untyped ) -> Array :
839+ def add (
840+ self ,
841+ y : Array ,
842+ / ,
843+ copy : bool | None = True ,
844+ xp : ModuleType | None = None ,
845+ ** kwargs : Untyped ,
846+ ) -> Array :
807847 """Apply ``x[idx] += y`` and return the updated array"""
808- return self ._iop ("add" , operator .add , y , ** kwargs )
848+ return self ._iop ("add" , operator .add , y , copy = copy , xp = xp , ** kwargs )
809849
810- def subtract (self , y : Array , / , ** kwargs : Untyped ) -> Array :
850+ def subtract (
851+ self ,
852+ y : Array ,
853+ / ,
854+ copy : bool | None = True ,
855+ xp : ModuleType | None = None ,
856+ ** kwargs : Untyped ,
857+ ) -> Array :
811858 """Apply ``x[idx] -= y`` and return the updated array"""
812- return self ._iop ("subtract" , operator .sub , y , ** kwargs )
859+ return self ._iop ("subtract" , operator .sub , y , copy = copy , xp = xp , ** kwargs )
813860
814- def multiply (self , y : Array , / , ** kwargs : Untyped ) -> Array :
861+ def multiply (
862+ self ,
863+ y : Array ,
864+ / ,
865+ copy : bool | None = True ,
866+ xp : ModuleType | None = None ,
867+ ** kwargs : Untyped ,
868+ ) -> Array :
815869 """Apply ``x[idx] *= y`` and return the updated array"""
816- return self ._iop ("multiply" , operator .mul , y , ** kwargs )
870+ return self ._iop ("multiply" , operator .mul , y , copy = copy , xp = xp , ** kwargs )
817871
818- def divide (self , y : Array , / , ** kwargs : Untyped ) -> Array :
872+ def divide (
873+ self ,
874+ y : Array ,
875+ / ,
876+ copy : bool | None = True ,
877+ xp : ModuleType | None = None ,
878+ ** kwargs : Untyped ,
879+ ) -> Array :
819880 """Apply ``x[idx] /= y`` and return the updated array"""
820- return self ._iop ("divide" , operator .truediv , y , ** kwargs )
881+ return self ._iop ("divide" , operator .truediv , y , copy = copy , xp = xp , ** kwargs )
821882
822- def power (self , y : Array , / , ** kwargs : Untyped ) -> Array :
883+ def power (
884+ self ,
885+ y : Array ,
886+ / ,
887+ copy : bool | None = True ,
888+ xp : ModuleType | None = None ,
889+ ** kwargs : Untyped ,
890+ ) -> Array :
823891 """Apply ``x[idx] **= y`` and return the updated array"""
824- return self ._iop ("power" , operator .pow , y , ** kwargs )
892+ return self ._iop ("power" , operator .pow , y , copy = copy , xp = xp , ** kwargs )
825893
826- def min (self , y : Array , / , ** kwargs : Untyped ) -> Array :
894+ def min (
895+ self ,
896+ y : Array ,
897+ / ,
898+ copy : bool | None = True ,
899+ xp : ModuleType | None = None ,
900+ ** kwargs : Untyped ,
901+ ) -> Array :
827902 """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
828- xp = array_namespace (self .x )
903+ if xp is None :
904+ xp = array_namespace (self .x )
829905 y = xp .asarray (y )
830- return self ._iop ("min" , xp .minimum , y , ** kwargs )
906+ return self ._iop ("min" , xp .minimum , y , copy = copy , xp = xp , ** kwargs )
831907
832- def max (self , y : Array , / , ** kwargs : Untyped ) -> Array :
908+ def max (
909+ self ,
910+ y : Array ,
911+ / ,
912+ copy : bool | None = True ,
913+ xp : ModuleType | None = None ,
914+ ** kwargs : Untyped ,
915+ ) -> Array :
833916 """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
834- xp = array_namespace (self .x )
917+ if xp is None :
918+ xp = array_namespace (self .x )
835919 y = xp .asarray (y )
836- return self ._iop ("max" , xp .maximum , y , ** kwargs )
920+ return self ._iop ("max" , xp .maximum , y , copy = copy , xp = xp , ** kwargs )
0 commit comments