@@ -220,7 +220,6 @@ def min(x: Array, /, *, axis: int | tuple[int, ...] |None = None, keepdims: bool
220220 return torch .clone (x )
221221 return torch .amin (x , axis , keepdims = keepdims )
222222
223- clip = get_xp (torch )(_aliases .clip )
224223unstack = get_xp (torch )(_aliases .unstack )
225224cumulative_sum = get_xp (torch )(_aliases .cumulative_sum )
226225cumulative_prod = get_xp (torch )(_aliases .cumulative_prod )
@@ -835,6 +834,38 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
835834 )
836835
837836
837+ def clip (
838+ x : Array ,
839+ / ,
840+ min : int | float | Array | None = None ,
841+ max : int | float | Array | None = None ,
842+ ** kwargs
843+ ) -> Array :
844+ def _isscalar (a : object ):
845+ return isinstance (a , int | float ) or a is None
846+
847+ # cf clip in common/_aliases.py
848+ if not x .is_floating_point ():
849+ if type (min ) is int and min <= torch .iinfo (x .dtype ).min :
850+ min = None
851+ if type (max ) is int and max >= torch .iinfo (x .dtype ).max :
852+ max = None
853+
854+ if min is None and max is None :
855+ return torch .clone (x )
856+
857+ min_is_scalar = _isscalar (min )
858+ max_is_scalar = _isscalar (max )
859+
860+ if min is not None and max is not None :
861+ if min_is_scalar and not max_is_scalar :
862+ min = torch .as_tensor (min , dtype = x .dtype , device = x .device )
863+ if max_is_scalar and not min_is_scalar :
864+ max = torch .as_tensor (max , dtype = x .dtype , device = x .device )
865+
866+ return torch .clamp (x , min , max , ** kwargs )
867+
868+
838869def sign (x : Array , / ) -> Array :
839870 # torch sign() does not support complex numbers and does not propagate
840871 # nans. See https://github.com/data-apis/array-api-compat/issues/136
0 commit comments