Skip to content

Commit 1a077c4

Browse files
authored
Merge pull request #378 from ev-br/fix_clip
BUG: torch: fix up clip
2 parents dd6d3e8 + a970410 commit 1a077c4

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from functools import reduce as _reduce, wraps as _wraps
55
from builtins import all as _builtin_all, any as _builtin_any
66
from typing import Any, Literal
7+
import math
78

89
import torch
910

@@ -857,13 +858,24 @@ def _isscalar(a: object):
857858
min_is_scalar = _isscalar(min)
858859
max_is_scalar = _isscalar(max)
859860

860-
if min is not None and max is not None:
861-
if min_is_scalar and not max_is_scalar:
862-
min = torch.as_tensor(min, dtype=x.dtype, device=x.device)
863-
if max_is_scalar and not min_is_scalar:
864-
max = torch.as_tensor(max, dtype=x.dtype, device=x.device)
861+
if min_is_scalar and max_is_scalar:
862+
if (min is not None and math.isnan(min)) or (max is not None and math.isnan(max)):
863+
# edge case: torch.clamp(torch.zeros(1), float('nan')) -> tensor(0.)
864+
# https://github.com/pytorch/pytorch/issues/172067
865+
return torch.full_like(x, fill_value=torch.nan)
866+
return torch.clamp(x, min, max, **kwargs)
865867

866-
return torch.clamp(x, min, max, **kwargs)
868+
# pytorch has (tensor, tensor, tensor) and (tensor, scalar, scalar) signatures,
869+
# but does not accept (tensor, scalar, tensor)
870+
a_min = min
871+
if min is not None and min_is_scalar:
872+
a_min = torch.as_tensor(min, dtype=x.dtype, device=x.device)
873+
874+
a_max = max
875+
if max is not None and max_is_scalar:
876+
a_max = torch.as_tensor(max, dtype=x.dtype, device=x.device)
877+
878+
return torch.clamp(x, a_min, a_max, **kwargs)
867879

868880

869881
def sign(x: Array, /) -> Array:

0 commit comments

Comments
 (0)