diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 91161d24..af3dffc5 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -220,7 +220,6 @@ def min(x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool
         return torch.clone(x)
     return torch.amin(x, axis, keepdims=keepdims)
 
-clip = get_xp(torch)(_aliases.clip)
 unstack = get_xp(torch)(_aliases.unstack)
 cumulative_sum = get_xp(torch)(_aliases.cumulative_sum)
 cumulative_prod = get_xp(torch)(_aliases.cumulative_prod)
@@ -808,6 +807,41 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
     return torch.take_along_dim(x, indices, dim=axis)
 
 
+def clip(
+    x: Array,
+    /,
+    min: int | float | Array | None = None,
+    max: int | float | Array | None = None,
+    **kwargs,
+) -> Array:
+    def _isscalar(a: object) -> bool:
+        return isinstance(a, int | float) or a is None
+
+    # Integer bounds at or beyond the dtype's range are no-ops; drop them
+    # rather than let them be cast into range (cf. clip in common/_aliases.py).
+    if not x.is_floating_point():
+        if type(min) is int and min <= torch.iinfo(x.dtype).min:
+            min = None
+        if type(max) is int and max >= torch.iinfo(x.dtype).max:
+            max = None
+
+    if min is None and max is None:
+        return torch.clone(x)
+
+    min_is_scalar = _isscalar(min)
+    max_is_scalar = _isscalar(max)
+
+    # torch.clamp rejects a mix of Number and Tensor bounds, so promote the
+    # scalar bound to a tensor when the other bound is a tensor.
+    if min is not None and max is not None:
+        if min_is_scalar and not max_is_scalar:
+            min = torch.as_tensor(min, dtype=x.dtype, device=x.device)
+        if max_is_scalar and not min_is_scalar:
+            max = torch.as_tensor(max, dtype=x.dtype, device=x.device)
+
+    return torch.clamp(x, min, max, **kwargs)
+
+
 def sign(x: Array, /) -> Array:
     # torch sign() does not support complex numbers and does not propagate
     # nans. See https://github.com/data-apis/array-api-compat/issues/136
diff --git a/tests/test_torch.py b/tests/test_torch.py
index 7adb4ab3..b3445a0e 100644
--- a/tests/test_torch.py
+++ b/tests/test_torch.py
@@ -102,6 +102,18 @@ def test_gh_273(self, default_dt, dtype_a, dtype_b):
         torch.set_default_dtype(prev_default)
 
 
+def test_clip_vmap():
+    # https://github.com/data-apis/array-api-compat/issues/350
+    def apply_clip_compat(a):
+        return xp.clip(a, min=0, max=30)
+
+    a = xp.asarray([[5.1, 2.0, 64.1, -1.5]])
+
+    ref = apply_clip_compat(a)
+    v1 = torch.vmap(apply_clip_compat)
+    assert xp.all(v1(a) == ref)
+
+
 def test_meshgrid():
     """Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'."""
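
A standalone sketch of what test_clip_vmap exercises, assuming a torch build that ships torch.vmap (2.0+). The generic clip in common/_aliases.py builds its result with in-place index assignment, which appears to be what the issue linked in the test reports breaking under torch.vmap; torch.clamp is a pure out-of-place op, so the torch-specific clip above composes with vmap:

    import torch

    def clip_compat(a):
        # torch.clamp returns a new tensor instead of mutating a
        # preallocated buffer, so vmap can batch it.
        return torch.clamp(a, min=0.0, max=30.0)

    a = torch.tensor([[5.1, 2.0, 64.1, -1.5]])
    print(torch.vmap(clip_compat)(a))
    # tensor([[ 5.1000,  2.0000, 30.0000,  0.0000]])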