3 changes: 3 additions & 0 deletions — CHANGELOG.md

@@ -51,6 +51,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed MIFID incorrectly converting inputs to `byte` dtype with custom encoders ([#3064](https://github.com/Lightning-AI/torchmetrics/pull/3064))
 
 
+- Fixed `_bincount` being overly restrictive ([#3087](https://github.com/Lightning-AI/torchmetrics/pull/3087))
+
+
 - Fixed `ignore_index` in `MultilabelExactMatch` ([#3085](https://github.com/Lightning-AI/torchmetrics/pull/3085))
 
 
14 changes: 8 additions & 6 deletions — src/torchmetrics/utilities/data.py

@@ -20,7 +20,7 @@
 from torch import Tensor
 
 from torchmetrics.utilities.exceptions import TorchMetricsUserWarning
-from torchmetrics.utilities.imports import _TORCH_LESS_THAN_2_6, _XLA_AVAILABLE
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1, _TORCH_LESS_THAN_2_6, _XLA_AVAILABLE
 from torchmetrics.utilities.prints import rank_zero_warn
 
 METRIC_EPS = 1e-6
@@ -178,10 +178,12 @@ def _squeeze_if_scalar(data: Any) -> Any:
 def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
     """Implement custom bincount.
 
-    PyTorch currently does not support ``torch.bincount`` when running in deterministic mode on GPU or when running
-    MPS devices or when running on XLA device. This implementation therefore falls back to using a combination of
-    `torch.arange` and `torch.eq` in these scenarios. A small performance hit can expected and higher memory consumption
-    as `[batch_size, mincount]` tensor needs to be initialized compared to native ``torch.bincount``.
+    As of PyTorch v2.1, ``torch.bincount`` is supported in deterministic mode on CUDA when no ``weights``
+    are provided and gradients are not required. However, the operation remains unsupported or limited on
+    some backends, such as MPS and XLA. In those cases, we fall back to a manual implementation using
+    ``torch.arange`` and ``torch.eq``. A small performance hit and higher memory consumption can be
+    expected, as a ``[batch_size, minlength]`` tensor needs to be initialized, compared to native
+    ``torch.bincount``.
 
     Args:
         x: tensor to count
@@ -199,7 +201,7 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
     if minlength is None:
         minlength = len(torch.unique(x))
 
-    if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or x.is_mps:
+    if (not _TORCH_GREATER_EQUAL_2_1 and torch.are_deterministic_algorithms_enabled()) or _XLA_AVAILABLE or x.is_mps:
         mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1)
         return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)
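
For reference, the fallback branch kept above is easy to check in isolation. Below is a minimal sketch of the same arange/eq counting trick; `bincount_fallback` is a hypothetical standalone name for illustration, not the library's private `_bincount`:

```python
import torch

def bincount_fallback(x: torch.Tensor, minlength: int) -> torch.Tensor:
    # Build a [len(x), minlength] grid holding every candidate value per row.
    mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1)
    # Compare each element of x against every candidate, then sum matches per value.
    return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)

x = torch.tensor([0, 2, 2, 1])
assert torch.equal(bincount_fallback(x, minlength=3), torch.bincount(x, minlength=3))
```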

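And a quick sketch of the scenario the relaxed gate targets, assuming torch >= 2.1 and an available CUDA device (per the new docstring, native ``torch.bincount`` is deterministic there when no weights are passed and no gradients are required):

```python
import torch

torch.use_deterministic_algorithms(True)

if torch.cuda.is_available():
    x = torch.tensor([0, 2, 2, 1], device="cuda")
    # On torch >= 2.1 this native call no longer raises in deterministic mode,
    # so the arange/eq fallback (and its [batch_size, minlength] intermediate)
    # is skipped on this path.
    counts = torch.bincount(x, minlength=4)
    print(counts)  # tensor([1, 1, 2, 0], device='cuda:0')
```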