7
7
"""
8
8
from __future__ import annotations
9
9
10
+ import operator
10
11
from typing import TYPE_CHECKING
11
12
12
13
if TYPE_CHECKING :
13
- from typing import Optional , Union , Any
14
+ from typing import Callable , Optional , Union , Any
14
15
from ._typing import Array , Device
15
16
16
17
import sys
@@ -811,9 +812,22 @@ def is_writeable_array(x):
811
812
return False
812
813
return True
813
814
815
+ def _parse_copy_param (x , copy : bool | None ) -> bool :
816
+ """Preprocess and validate a copy parameter, in line with the same
817
+ parameter in np.asarray(), np.astype(), etc.
818
+ """
819
+ if copy is None :
820
+ return not is_writeable_array (x )
821
+ if copy is False :
822
+ if not is_writeable_array (x ):
823
+ raise ValueError ("Cannot avoid modifying parameter in place" )
824
+ elif copy is not True :
825
+ raise ValueError (f"Invalid value for copy: { copy !r} " )
826
+ return copy
827
+
814
828
_undef = object ()
815
829
816
- def at ( x , idx = _undef , / ) :
830
+ class at :
817
831
"""
818
832
Update operations for read-only arrays.
819
833
@@ -823,12 +837,22 @@ def at(x, idx=_undef, /):
823
837
Keyword arguments (e.g. ``indices_are_sorted``) are passed to JAX and are
824
838
quietly ignored for backends that don't support them.
825
839
840
+ Additionally, this introduces support for the `copy` keyword for all backends:
841
+
842
+ None
843
+ x *may* be modified in place if it is possible and beneficial
844
+ for performance. You should not use x after calling this function.
845
+ True
846
+ Ensure that the inputs are not modified. This is the default.
847
+ False
848
+ Raise ValueError if a copy cannot be avoided.
849
+
826
850
Examples
827
851
--------
828
852
Given either of these equivalent expressions::
829
853
830
- x = at(x)[1].add(2)
831
- x = at(x, 1).add(2)
854
+ x = at(x)[1].add(2, copy=None )
855
+ x = at(x, 1).add(2, copy=None )
832
856
833
857
If x is a JAX array, they are the same as::
834
858
@@ -845,16 +869,17 @@ def at(x, idx=_undef, /):
845
869
846
870
Warning
847
871
-------
848
- You should always immediately overwrite the parameter array::
872
+ When you use copy=None, you should always immediately overwrite
873
+ the parameter array::
849
874
850
- x = at(x, 0).set(2)
875
+ x = at(x, 0).set(2, copy=None )
851
876
852
877
The anti-pattern below must be avoided, as it will result in different behaviour
853
878
on read-only versus writeable arrays:
854
879
855
880
x = xp.asarray([0, 0, 0])
856
- y = at(x, 0).set(2)
857
- z = at(x, 1).set(3)
881
+ y = at(x, 0).set(2, copy=None )
882
+ z = at(x, 1).set(3, copy=None )
858
883
859
884
In the above example, y == [2, 0, 0] and z == [0, 3, 0] when x is read-only,
860
885
whereas y == z == [2, 3, 0] when x is writeable!
@@ -863,18 +888,6 @@ def at(x, idx=_undef, /):
863
888
--------
864
889
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
865
890
"""
866
- if is_jax_array (x ):
867
- return x .at
868
- if is_numpy_array (x ) and not x .flags .writeable :
869
- x = x .copy ()
870
- return _InPlaceAt (x , idx )
871
-
872
- class _InPlaceAt :
873
- """Helper of at().
874
-
875
- Trivially implement jax.numpy.ndarray.at for other backends.
876
- x is updated in place.
877
- """
878
891
__slots__ = ("x" , "idx" )
879
892
880
893
def __init__ (self , x , idx = _undef ):
@@ -890,7 +903,16 @@ def __getitem__(self, idx):
890
903
self .idx = idx
891
904
return self
892
905
893
- def _check_args (self , mode = "promise_in_bounds" , ** kwargs ):
906
+ def _common (self , at_op , y = _undef , mode : str = "promise_in_bounds" , ** kwargs ):
907
+ """Validate kwargs and perform common prepocessing.
908
+
909
+ Returns
910
+ -------
911
+ If the operation can be resolved by at[],
912
+ (return value, None)
913
+ Otherwise,
914
+ (None, preprocessed x)
915
+ """
894
916
if self .idx is _undef :
895
917
raise TypeError (
896
918
"Index has not been set.\n "
@@ -900,74 +922,136 @@ def _check_args(self, mode="promise_in_bounds", **kwargs):
900
922
" at(x)[idx].set(value)\n "
901
923
"(same for all other methods)."
902
924
)
903
- if mode != "promise_in_bounds" :
925
+ if mode != "promise_in_bounds" and not is_jax_array ( self . x ) :
904
926
xp = array_namespace (self .x )
905
927
raise NotImplementedError (
906
- f"mode='{ mode } ' is not supported for backend { xp .__name__ } "
928
+ f"mode='{ mode !r} ' is not supported for backend { xp .__name__ } "
929
+ )
930
+
931
+ copy = _parse_copy_param (self .x , copy )
932
+
933
+ if copy and is_jax_array (self .x ):
934
+ # Use JAX's at[]
935
+ at_ = self .x .at [self .idx ]
936
+ args = (y , ) if y is not _undef else ()
937
+ return getattr (at_ , at_op )(* args , mode = mode , ** kwargs ), None
938
+
939
+ # Emulate at[] behaviour for non-JAX arrays
940
+ x = self .x .copy () if copy else self .x
941
+ return None , x
942
+
943
+ def get (self , copy : bool | None = True , ** kwargs ):
944
+ """Return x[idx]. In addition to plain __getitem__, this allows ensuring
945
+ that the output is (not) a copy and kwargs are passed to the backend."""
946
+ # Special case when xp=numpy and idx is a fancy index
947
+ # If copy is not False, avoid an unnecessary double copy.
948
+ # if copy is forced to False, raise.
949
+ if (
950
+ is_numpy_array (self .x )
951
+ and (
952
+ isinstance (self .idx , (list , tuple ))
953
+ or (is_numpy_array (self .idx ) and self .idx .dtype .kind in "biu" )
907
954
)
955
+ ):
956
+ if copy is True :
957
+ copy = None
958
+ elif copy is False :
959
+ raise ValueError (
960
+ "Indexing a numpy array with a fancy index always "
961
+ "results in a copy"
962
+ )
963
+
964
+ res , x = self ._common ("get" , copy = copy , ** kwargs )
965
+ if res is not None :
966
+ return res
967
+ return x [self .idx ]
908
968
909
969
def set (self , y , / , ** kwargs ):
910
- self ._check_args (** kwargs )
911
- self .x [self .idx ] = y
912
- return self .x
970
+ """x[idx] = y"""
971
+ res , x = self ._common ("set" , y , ** kwargs )
972
+ if res is not None :
973
+ return res
974
+ x [self .idx ] = y
975
+ return x
976
+
977
+ def apply (self , ufunc , / , ** kwargs ):
978
+ """ufunc.at(x, idx)"""
979
+ res , x = self ._common ("apply" , ufunc , ** kwargs )
980
+ if res is not None :
981
+ return res
982
+ ufunc .at (x , self .idx )
983
+ return x
984
+
985
+ def _iop (self , at_op : str , elwise_op : Callable [[Array , Array ], Array ], y : Array , ** kwargs ):
986
+ """x[idx] += y or equivalent in-place operation on a subset of x
987
+
988
+ which is the same as saying
989
+ x[idx] = x[idx] + y
990
+ Note that this is not the same as
991
+ operator.iadd(x[idx], y)
992
+ Consider for example when x is a numpy array and idx is a fancy index, which
993
+ triggers a deep copy on __getitem__.
994
+ """
995
+ res , x = self ._common (at_op , y , ** kwargs )
996
+ if res is not None :
997
+ return res
998
+ x [self .idx ] = elwise_op (x [self .idx ], y )
999
+ return x
913
1000
914
1001
def add (self , y , / , ** kwargs ):
915
- self ._check_args (** kwargs )
916
- self .x [self .idx ] += y
917
- return self .x
918
-
1002
+ """x[idx] += y"""
1003
+ return self ._iop ("add" , operator .add , y , ** kwargs )
1004
+
919
1005
def subtract (self , y , / , ** kwargs ):
920
- self ._check_args (** kwargs )
921
- self .x [self .idx ] -= y
922
- return self .x
1006
+ """x[idx] -= y"""
1007
+ return self ._iop ("subtract" , operator .sub , y , ** kwargs )
923
1008
924
1009
def multiply (self , y , / , ** kwargs ):
925
- self ._check_args (** kwargs )
926
- self .x [self .idx ] *= y
927
- return self .x
1010
+ """x[idx] *= y"""
1011
+ return self ._iop ("multiply" , operator .mul , y , ** kwargs )
928
1012
929
1013
def divide (self , y , / , ** kwargs ):
930
- self ._check_args (** kwargs )
931
- self .x [self .idx ] /= y
932
- return self .x
933
-
1014
+ """x[idx] /= y"""
1015
+ return self ._iop ("divide" , operator .truediv , y , ** kwargs )
1016
+
934
1017
def power (self , y , / , ** kwargs ):
935
- self ._check_args (** kwargs )
936
- self .x [self .idx ] **= y
937
- return self .x
1018
+ """x[idx] **= y"""
1019
+ return self ._iop ("power" , operator .pow , y , ** kwargs )
938
1020
939
1021
def min (self , y , / , ** kwargs ):
940
- self ._check_args (** kwargs )
941
- xp = array_namespace (self .x , y )
942
- self .x [self .idx ] = xp .minimum (self .x [self .idx ], y )
943
- return self .x
1022
+ """x[idx] = minimum(x[idx], y)"""
1023
+ xp = array_namespace (self .x )
1024
+ return self ._iop ("min" , xp .minimum , y , ** kwargs )
944
1025
945
1026
def max (self , y , / , ** kwargs ):
946
- self ._check_args (** kwargs )
947
- xp = array_namespace (self .x , y )
948
- self .x [self .idx ] = xp .maximum (self .x [self .idx ], y )
949
- return self .x
950
-
951
- def apply (self , ufunc , / , ** kwargs ):
952
- self ._check_args (** kwargs )
953
- ufunc .at (self .x , self .idx )
954
- return self .x
955
-
956
- def get (self , ** kwargs ):
957
- self ._check_args (** kwargs )
958
- return self .x [self .idx ]
959
-
960
- def iwhere (condition , x , y , / ):
961
- """Variant of xp.where(condition, x, y) which may or may not update
962
- x in place, if it's possible and beneficial for performance.
1027
+ """x[idx] = maximum(x[idx], y)"""
1028
+ xp = array_namespace (self .x )
1029
+ return self ._iop ("max" , xp .maximum , y , ** kwargs )
1030
+
1031
+ def where (condition , x , y , / , copy : bool | None = True ):
1032
+ """Return elements from x when condition is True and from y when
1033
+ it is False.
1034
+
1035
+ This is a wrapper around xp.where that adds the copy parameter:
1036
+
1037
+ None
1038
+ x *may* be modified in place if it is possible and beneficial
1039
+ for performance. You should not use x after calling this function.
1040
+ True
1041
+ Ensure that the inputs are not modified.
1042
+ This is the default, in line with np.where.
1043
+ False
1044
+ Raise ValueError if a copy cannot be avoided.
963
1045
"""
1046
+ copy = _parse_copy_param (x , copy )
964
1047
xp = array_namespace (condition , x , y )
965
- if is_writeable_array (x ):
1048
+ if copy :
1049
+ return xp .where (condition , x , y )
1050
+ else :
966
1051
condition , x , y = xp .broadcast_arrays (condition , x , y )
967
1052
x [condition ] = y [condition ]
968
1053
return x
969
- else :
970
- return xp .where (condition , x , y )
1054
+
971
1055
972
1056
__all__ = [
973
1057
"array_namespace" ,
@@ -993,7 +1077,7 @@ def iwhere(condition, x, y, /):
993
1077
"size" ,
994
1078
"to_device" ,
995
1079
"at" ,
996
- "iwhere " ,
1080
+ "where " ,
997
1081
]
998
1082
999
1083
_all_ignore = ['sys' , 'math' , 'inspect' , 'warnings' ]
0 commit comments