33 changes: 32 additions & 1 deletion array_api_compat/torch/_aliases.py
@@ -220,7 +220,6 @@ def min(x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool
         return torch.clone(x)
     return torch.amin(x, axis, keepdims=keepdims)

-clip = get_xp(torch)(_aliases.clip)
 unstack = get_xp(torch)(_aliases.unstack)
 cumulative_sum = get_xp(torch)(_aliases.cumulative_sum)
 cumulative_prod = get_xp(torch)(_aliases.cumulative_prod)
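The deleted line above drops the generic `clip` binding that routed through `common/_aliases.py` via `get_xp`; the hunk below replaces it with a torch-native implementation.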
@@ -808,6 +807,38 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
     return torch.take_along_dim(x, indices, dim=axis)


+def clip(
+    x: Array,
+    /,
+    min: int | float | Array | None = None,
+    max: int | float | Array | None = None,
+    **kwargs
+) -> Array:
+    def _isscalar(a: object):
+        return isinstance(a, int | float) or a is None
+
+    # cf clip in common/_aliases.py
+    if not x.is_floating_point():
+        if type(min) is int and min <= torch.iinfo(x.dtype).min:
+            min = None
+        if type(max) is int and max >= torch.iinfo(x.dtype).max:
+            max = None
+
+    if min is None and max is None:
+        return torch.clone(x)
+
+    min_is_scalar = _isscalar(min)
+    max_is_scalar = _isscalar(max)
+
+    if min is not None and max is not None:
+        if min_is_scalar and not max_is_scalar:
+            min = torch.as_tensor(min, dtype=x.dtype, device=x.device)
+        if max_is_scalar and not min_is_scalar:
+            max = torch.as_tensor(max, dtype=x.dtype, device=x.device)
+
+    return torch.clamp(x, min, max, **kwargs)
+
+
 def sign(x: Array, /) -> Array:
     # torch sign() does not support complex numbers and does not propagate
     # nans. See https://github.com/data-apis/array-api-compat/issues/136
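For context, a minimal usage sketch of the torch-native `clip` added above (a sketch, assuming `array_api_compat` is importable as below; the values are illustrative):

```python
import array_api_compat.torch as xp

x = xp.asarray([1, 5, 9], dtype=xp.int8)

# Integer bounds at or beyond the dtype's range are dropped by the
# is_floating_point branch, so this reduces to torch.clone(x) instead of
# overflowing inside torch.clamp.
print(xp.clip(x, min=-128, max=127))  # tensor([1, 5, 9], dtype=torch.int8)

# Mixing a scalar bound with a tensor bound promotes the scalar to a
# tensor of x's dtype and device before calling torch.clamp.
hi = xp.asarray([4, 6, 8], dtype=xp.int8)
print(xp.clip(x, min=2, max=hi))      # tensor([2, 5, 8], dtype=torch.int8)
```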
12 changes: 12 additions & 0 deletions tests/test_torch.py
@@ -102,6 +102,18 @@ def test_gh_273(self, default_dt, dtype_a, dtype_b):
         torch.set_default_dtype(prev_default)


+def test_clip_vmap():
+    # https://github.com/data-apis/array-api-compat/issues/350
+    def apply_clip_compat(a):
+        return xp.clip(a, min=0, max=30)
+
+    a = xp.asarray([[5.1, 2.0, 64.1, -1.5]])
+
+    ref = apply_clip_compat(a)
+    v1 = torch.vmap(apply_clip_compat)
+    assert xp.all(v1(a) == ref)
+
+
 def test_meshgrid():
     """Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'."""

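A standalone sketch of what `test_clip_vmap` exercises (assumes a torch build that provides `torch.vmap`, i.e. torch >= 2.0):

```python
import torch
import array_api_compat.torch as xp

def apply_clip_compat(a):
    return xp.clip(a, min=0, max=30)

a = xp.asarray([[5.1, 2.0, 64.1, -1.5]])

# Unbatched reference: each value is clamped into [0, 30].
print(apply_clip_compat(a))  # tensor([[ 5.1000,  2.0000, 30.0000,  0.0000]])

# vmap applies the function once per row of the leading batch dimension;
# building clip directly on torch.clamp keeps this path vmap-compatible
# (see gh-350 linked in the test).
assert xp.all(torch.vmap(apply_clip_compat)(a) == apply_clip_compat(a))
```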