2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -47,6 +47,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed MIFID incorrectly converts inputs to `byte` dtype with custom encoders ([#3064](https://github.com/Lightning-AI/torchmetrics/pull/3064))


- Fixed `_bincount` being less restrictive ([#3087](https://github.com/Lightning-AI/torchmetrics/pull/3087))

---

## [1.7.1] - 2025-04-06
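
For context, a quick sketch (mine, not part of the diff) of what this entry means in practice: assuming PyTorch >= 2.1, enabling deterministic algorithms no longer forces the internal `_bincount` helper onto the memory-heavy fallback path (see the `data.py` changes below).

```python
# Hedged sketch, assuming PyTorch >= 2.1 and a CPU tensor: with this PR,
# deterministic mode alone no longer routes _bincount through the
# arange/eq fallback; the native torch.bincount kernel is used instead.
import torch
from torchmetrics.utilities.data import _bincount

torch.use_deterministic_algorithms(True)
x = torch.tensor([0, 1, 1, 2])
print(_bincount(x, minlength=3))  # tensor([1, 2, 1])
```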
14 changes: 8 additions & 6 deletions src/torchmetrics/utilities/data.py
@@ -20,7 +20,7 @@
from torch import Tensor

from torchmetrics.utilities.exceptions import TorchMetricsUserWarning
from torchmetrics.utilities.imports import _TORCH_LESS_THAN_2_6, _XLA_AVAILABLE
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1, _TORCH_LESS_THAN_2_6, _XLA_AVAILABLE
from torchmetrics.utilities.prints import rank_zero_warn

METRIC_EPS = 1e-6
@@ -178,10 +178,12 @@ def _squeeze_if_scalar(data: Any) -> Any:
def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
"""Implement custom bincount.

PyTorch currently does not support ``torch.bincount`` when running in deterministic mode on GPU or when running
MPS devices or when running on XLA device. This implementation therefore falls back to using a combination of
`torch.arange` and `torch.eq` in these scenarios. A small performance hit can expected and higher memory consumption
as `[batch_size, mincount]` tensor needs to be initialized compared to native ``torch.bincount``.
As of PyTorch v2.1, ``torch.bincount`` is supported in deterministic mode on CUDA
when no ``weights`` are provided and gradients are not required. However, this
operation remains unsupported or limited on some backends, such as MPS and XLA.
In those cases, we fall back to a manual implementation using `torch.arange` and `torch.eq`.
A small performance hit and higher memory consumption can be expected, as a `[batch_size, minlength]`
tensor needs to be initialized compared to native ``torch.bincount``.

Args:
x: tensor to count
@@ -199,7 +201,7 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
if minlength is None:
minlength = len(torch.unique(x))

if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or x.is_mps:
if (not _TORCH_GREATER_EQUAL_2_1 and torch.are_deterministic_algorithms_enabled()) or _XLA_AVAILABLE or x.is_mps:
mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1)
return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)
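
As a sanity check (my sketch, not part of the PR), the `torch.arange`/`torch.eq` fallback kept above agrees with native `torch.bincount` on non-negative integer input:

```python
# Standalone sketch of the fallback path: build a [len(x), minlength] mesh
# of all bin indices, then count, per bin, how many elements match it.
import torch

x = torch.tensor([0, 2, 2, 3])
minlength = int(x.max()) + 1  # 4 bins

mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1)
fallback = torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)

assert torch.equal(fallback, torch.bincount(x, minlength=minlength))
print(fallback)  # tensor([1, 0, 2, 1])
```

The mesh materializes every `(element, bin)` pair, which is why memory grows as `len(x) * minlength` relative to the native kernel.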
