diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b830ffd5d1..b624d2b9e10 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,7 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed `_bincount` being less restrictive ([#3087](https://github.com/Lightning-AI/torchmetrics/pull/3087)) --- diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index dc67d5a4e34..1f71e0e75c6 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -20,7 +20,7 @@ from torch import Tensor from torchmetrics.utilities.exceptions import TorchMetricsUserWarning -from torchmetrics.utilities.imports import _TORCH_LESS_THAN_2_6, _XLA_AVAILABLE +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1, _TORCH_LESS_THAN_2_6, _XLA_AVAILABLE from torchmetrics.utilities.prints import rank_zero_warn METRIC_EPS = 1e-6 @@ -178,10 +178,12 @@ def _squeeze_if_scalar(data: Any) -> Any: def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor: """Implement custom bincount. - PyTorch currently does not support ``torch.bincount`` when running in deterministic mode on GPU or when running - MPS devices or when running on XLA device. This implementation therefore falls back to using a combination of - `torch.arange` and `torch.eq` in these scenarios. A small performance hit can expected and higher memory consumption - as `[batch_size, mincount]` tensor needs to be initialized compared to native ``torch.bincount``. + As of PyTorch v2.1, ``torch.bincount`` is supported in deterministic mode on CUDA + when no ``weights`` are provided and gradients are not required. However, this + operation remains unsupported or limited on some backends, such as MPS and XLA. + In those cases, we fall back to a manual implementation using `torch.arange` and `torch.eq`. + A small performance hit can expected and higher memory consumption as `[batch_size, mincount]` + tensor needs to be initialized compared to native ``torch.bincount``. Args: x: tensor to count @@ -199,7 +201,7 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor: if minlength is None: minlength = len(torch.unique(x)) - if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or x.is_mps: + if (not _TORCH_GREATER_EQUAL_2_1 and torch.are_deterministic_algorithms_enabled()) or _XLA_AVAILABLE or x.is_mps: mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1) return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)