Skip to content

Commit dd6d3e8

Browse files
authored
Merge pull request #353 from ev-br/torch_clip
ENH: use torch.clamp for wrapped_torch.clip
2 parents a732948 + d794015 commit dd6d3e8

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,6 @@ def min(x: Array, /, *, axis: int | tuple[int, ...] |None = None, keepdims: bool
220220
return torch.clone(x)
221221
return torch.amin(x, axis, keepdims=keepdims)
222222

223-
clip = get_xp(torch)(_aliases.clip)
224223
unstack = get_xp(torch)(_aliases.unstack)
225224
cumulative_sum = get_xp(torch)(_aliases.cumulative_sum)
226225
cumulative_prod = get_xp(torch)(_aliases.cumulative_prod)
@@ -835,6 +834,38 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
835834
)
836835

837836

837+
def clip(
838+
x: Array,
839+
/,
840+
min: int | float | Array | None = None,
841+
max: int | float | Array | None = None,
842+
**kwargs
843+
) -> Array:
844+
def _isscalar(a: object):
845+
return isinstance(a, int | float) or a is None
846+
847+
# cf clip in common/_aliases.py
848+
if not x.is_floating_point():
849+
if type(min) is int and min <= torch.iinfo(x.dtype).min:
850+
min = None
851+
if type(max) is int and max >= torch.iinfo(x.dtype).max:
852+
max = None
853+
854+
if min is None and max is None:
855+
return torch.clone(x)
856+
857+
min_is_scalar = _isscalar(min)
858+
max_is_scalar = _isscalar(max)
859+
860+
if min is not None and max is not None:
861+
if min_is_scalar and not max_is_scalar:
862+
min = torch.as_tensor(min, dtype=x.dtype, device=x.device)
863+
if max_is_scalar and not min_is_scalar:
864+
max = torch.as_tensor(max, dtype=x.dtype, device=x.device)
865+
866+
return torch.clamp(x, min, max, **kwargs)
867+
868+
838869
def sign(x: Array, /) -> Array:
839870
# torch sign() does not support complex numbers and does not propagate
840871
# nans. See https://github.com/data-apis/array-api-compat/issues/136

tests/test_torch.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,18 @@ def test_gh_273(self, default_dt, dtype_a, dtype_b):
102102
torch.set_default_dtype(prev_default)
103103

104104

105+
def test_clip_vmap():
106+
# https://github.com/data-apis/array-api-compat/issues/350
107+
def apply_clip_compat(a):
108+
return xp.clip(a, min=0, max=30)
109+
110+
a = xp.asarray([[5.1, 2.0, 64.1, -1.5]])
111+
112+
ref = apply_clip_compat(a)
113+
v1 = torch.vmap(apply_clip_compat)
114+
assert xp.all(v1(a) == ref)
115+
116+
105117
def test_meshgrid():
106118
"""Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'."""
107119

0 commit comments

Comments
 (0)