@@ -801,6 +801,54 @@ def size(x):
801801 return None
802802 return math .prod (x .shape )
803803
804+
805+ def is_writeable_array (x ):
806+ """
807+ Return False if x.__setitem__ is expected to raise; True otherwise
808+ """
809+ if is_numpy_array (x ):
810+ return x .flags .writeable
811+ if is_jax_array (x ) or is_pydata_sparse_array (x ):
812+ return False
813+ return True
814+
815+
816+ def where (condition , x = None , y = None , / , copy : bool | None = True ):
817+ """Return elements from x when condition is True and from y when
818+ it is False.
819+
820+ This is a wrapper around xp.where that adds the copy parameter:
821+
822+ None
823+ x *may* be modified in place if it is possible and beneficial
824+ for performance. You should not use x after calling this function.
825+ True
826+ Ensure that the inputs are not modified.
827+ This is the default, in line with np.where.
828+ False
829+ Raise ValueError if a copy cannot be avoided.
830+ """
831+ if x is None and y is None :
832+ xp = array_namespace (condition , use_compat = False )
833+ return xp .where (condition )
834+
835+ if copy is False :
836+ if not is_writeable_array (x ):
837+ raise ValueError ("Cannot modify parameter in place" )
838+ elif copy is None :
839+ copy = not is_writeable_array (x )
840+ elif copy is not True :
841+ raise ValueError (f"Invalid value for copy: { copy !r} " )
842+
843+ xp = array_namespace (condition , x , y , use_compat = False )
844+ if copy :
845+ return xp .where (condition , x , y )
846+ else :
847+ condition , x , y = xp .broadcast_arrays (condition , x , y )
848+ x [condition ] = y [condition ]
849+ return x
850+
851+
804852__all__ = [
805853 "array_namespace" ,
806854 "device" ,
@@ -821,8 +869,10 @@ def size(x):
821869 "is_ndonnx_namespace" ,
822870 "is_pydata_sparse_array" ,
823871 "is_pydata_sparse_namespace" ,
872+ "is_writeable_array" ,
824873 "size" ,
825874 "to_device" ,
875+ "where" ,
826876]
827877
828878_all_ignore = ['sys' , 'math' , 'inspect' , 'warnings' ]
0 commit comments