@@ -909,8 +909,7 @@ def _common(
909
909
self ,
910
910
at_op : str ,
911
911
y = _undef ,
912
- copy : bool | None = True ,
913
- mode : str = "promise_in_bounds" ,
912
+ copy : bool | None = True ,
914
913
** kwargs ,
915
914
):
916
915
"""Validate kwargs and perform common prepocessing.
@@ -931,19 +930,14 @@ def _common(
931
930
" at(x)[idx].set(value)\n "
932
931
"(same for all other methods)."
933
932
)
934
- if mode != "promise_in_bounds" and not is_jax_array (self .x ):
935
- xp = array_namespace (self .x )
936
- raise NotImplementedError (
937
- f"mode='{ mode !r} ' is not supported for backend { xp .__name__ } "
938
- )
939
933
940
934
copy = _parse_copy_param (self .x , copy )
941
935
942
936
if copy and is_jax_array (self .x ):
943
937
# Use JAX's at[]
944
938
at_ = self .x .at [self .idx ]
945
939
args = (y , ) if y is not _undef else ()
946
- return getattr (at_ , at_op )(* args , mode = mode , ** kwargs ), None
940
+ return getattr (at_ , at_op )(* args , ** kwargs ), None
947
941
948
942
# Emulate at[] behaviour for non-JAX arrays
949
943
x = self .x .copy () if copy else self .x
0 commit comments