🐛 Bug
torchmetrics.functional.classification.binary_auroc always returns 0.5 when all logits are large. This seems to be caused by a floating-point precision error in the sigmoid applied to the logits.
To Reproduce
Code sample
import torch
import torchmetrics.functional.classification
preds = torch.tensor([98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425, 99.2644, 99.5014, 99.7280, 99.6595, 99.6931, 99.4667, 99.9623, 99.8949, 99.8768])
labels = torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
torchmetrics.functional.classification.binary_auroc(preds, labels)
Output:
tensor(0.5000)
Expected behavior
The AUROC for the above example should be 0.9286, as computed by sklearn:
import sklearn.metrics
sklearn.metrics.roc_auc_score(labels, preds)
Output:
0.9285714285714286
Environment
- Windows 11 24H2
- Python version 3.10.11
- TorchMetrics version 1.4.3
- PyTorch version 2.4.1+cu124
Additional context
This appears to be a floating-point precision problem with the sigmoid applied at line 185 of the function _binary_precision_recall_curve_format in torchmetrics/src/torchmetrics/functional/classification/precision_recall_curve.py.
I extracted all the necessary functions and made a miniature binary_auroc function that uses exactly the same algorithm (it works for the above example; I did not test other examples):
def binary_auroc(
    preds: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    preds = preds.sigmoid()
    print(preds)
    desc_score_indices = torch.argsort(preds, descending=True)
    preds = preds[desc_score_indices]
    target = target[desc_score_indices]
    # print(preds, target)
    # pred typically has many tied values. Here we extract
    # the indices associated with the distinct values. We also
    # concatenate a value for the end of the curve.
    distinct_value_indices = torch.nonzero(preds[1:] - preds[:-1], as_tuple=True)[0]
    # print(distinct_value_indices)
    threshold_idxs = torch.nn.functional.pad(distinct_value_indices, [0, 1], value=target.size(0) - 1)
    # print(threshold_idxs)
    tps = torch.cumsum(target, dim=0)[threshold_idxs]
    fps = 1 + threshold_idxs - tps
    # print(tps, fps)
    # Add an extra threshold position to make sure that the curve starts at (0, 0)
    tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
    fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
    tpr = tps / tps[-1]
    fpr = fps / fps[-1]
    # print(fpr, tpr)
    return torch.trapezoid(tpr, fpr, dim=-1)
Output:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor(0.5000)
preds = preds.sigmoid() converts all the logits to 1, as if they were identical, which is not the case. The magnitude of a logit must be less than 36.74 for double, or 16.64 for float32, to avoid being converted to exactly 1. Because every pred is then tied at 1, distinct_value_indices is empty and the ROC curve degenerates to the single segment from (0, 0) to (1, 1), whose trapezoidal area is 0.5.
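For illustration (not part of the original report), here is a quick check of where sigmoid saturates; the probe values are chosen safely on either side of those cutoffs, since the exact thresholds depend on how sigmoid is implemented internally:

import torch

print(torch.sigmoid(torch.tensor(10.0)) == 1.0)                       # tensor(False): float32 still below 1
print(torch.sigmoid(torch.tensor(20.0)) == 1.0)                       # tensor(True): float32 has saturated
print(torch.sigmoid(torch.tensor(30.0, dtype=torch.float64)) == 1.0)  # tensor(False): float64 still below 1
print(torch.sigmoid(torch.tensor(40.0, dtype=torch.float64)) == 1.0)  # tensor(True): even float64 saturates
print(torch.sigmoid(torch.tensor([98.0, 99.0])))                      # tensor([1., 1.]): the ordering is lost

Once two different logits both map to exactly 1.0, the ROC computation sees them as tied, which is what collapses the AUROC to 0.5.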
Suggested fix
It's probably a good idea to scale the raw logits before applying sigmoid, something like the following:
preds /= torch.max(torch.abs(preds)) # scales max element to 1
preds = preds.sigmoid()
All functions that apply sigmoid to raw logits will need such a fix.
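As a sanity check (a sketch only, reusing the example tensors and the miniature binary_auroc defined above rather than patching the library), applying this scaling before the sigmoid recovers the expected value. Scaling by a positive constant preserves the ordering of the logits, so the resulting AUROC matches sklearn:

# Assumes `preds`, `labels`, and the miniature `binary_auroc` from the snippets above.
scaled = preds / torch.max(torch.abs(preds))  # largest-magnitude logit becomes 1
print(binary_auroc(scaled, labels))           # sigmoid no longer saturates; prints tensor(0.9286)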