 from functools import reduce as _reduce, wraps as _wraps
 from builtins import all as _builtin_all, any as _builtin_any
 from typing import Any, Literal
+import math

 import torch

@@ -857,13 +858,24 @@ def _isscalar(a: object):
     min_is_scalar = _isscalar(min)
     max_is_scalar = _isscalar(max)

-    if min is not None and max is not None:
-        if min_is_scalar and not max_is_scalar:
-            min = torch.as_tensor(min, dtype=x.dtype, device=x.device)
-        if max_is_scalar and not min_is_scalar:
-            max = torch.as_tensor(max, dtype=x.dtype, device=x.device)
+    if min_is_scalar and max_is_scalar:
+        if (min is not None and math.isnan(min)) or (max is not None and math.isnan(max)):
+            # edge case: torch.clamp(torch.zeros(1), float('nan')) -> tensor(0.)
+            # https://github.com/pytorch/pytorch/issues/172067
+            return torch.full_like(x, fill_value=torch.nan)
+        return torch.clamp(x, min, max, **kwargs)

-    return torch.clamp(x, min, max, **kwargs)
+    # pytorch has (tensor, tensor, tensor) and (tensor, scalar, scalar) signatures,
+    # but does not accept (tensor, scalar, tensor)
+    a_min = min
+    if min is not None and min_is_scalar:
+        a_min = torch.as_tensor(min, dtype=x.dtype, device=x.device)
+
+    a_max = max
+    if max is not None and max_is_scalar:
+        a_max = torch.as_tensor(max, dtype=x.dtype, device=x.device)
+
+    return torch.clamp(x, a_min, a_max, **kwargs)


 def sign(x: Array, /) -> Array:
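For reference, a minimal standalone sketch of the two torch.clamp quirks the patch comments describe (assuming a recent PyTorch build; exact output and error text may vary by version):

import torch

x = torch.zeros(3)

# Quirk 1: a scalar NaN bound is silently ignored rather than propagating NaN
# (per the patch comment and the linked pytorch/pytorch issue), which is why
# the wrapper special-cases NaN and returns a NaN-filled tensor itself.
print(torch.clamp(x, float("nan")))  # tensor([0., 0., 0.]), not NaN

# Quirk 2: clamp accepts (tensor, scalar, scalar) and (tensor, tensor, tensor)
# bounds, but not a mix such as (tensor, scalar, tensor).
try:
    torch.clamp(x, 0.25, torch.full_like(x, 0.75))
except TypeError as exc:
    print(exc)

# The patch therefore promotes a lone scalar bound to a tensor first:
a_min = torch.as_tensor(0.25, dtype=x.dtype, device=x.device)
print(torch.clamp(x, a_min, torch.full_like(x, 0.75)))  # tensor([0.2500, 0.2500, 0.2500])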