
torchmetrics.functional.classification.binary_auroc gives wrong results when logits are large #2819

@zbingsong


🐛 Bug

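torchmetrics.functional.classification.binary_auroc always returns 0.5 when all logits are large and positive, no matter how well they rank the labels. The cause appears to be floating-point precision loss in the sigmoid: every large logit is mapped to exactly 1.0, so all predictions become tied.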

To Reproduce

Code sample
import torch
import torchmetrics.functional.classification

preds = torch.tensor([98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425, 99.2644, 99.5014, 99.7280, 99.6595, 99.6931, 99.4667, 99.9623, 99.8949, 99.8768])
labels = torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

torchmetrics.functional.classification.binary_auroc(preds, labels)

Output:

tensor(0.5000)

Expected behavior

AUROC of the above example should be 0.9286, as computed by sklearn.

import sklearn.metrics

sklearn.metrics.roc_auc_score(labels, preds)

Output:

0.9285714285714286
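Both numbers can be checked by hand with the rank-based (Mann-Whitney) form of AUROC, which equals the trapezoidal ROC area when tied scores get half credit. In the worked equation below, $s_i$ denotes a score, $y_i$ its label, and $n_+$, $n_-$ the numbers of positives and negatives (notation introduced here for illustration). The single negative has logit 98.0950, and only one positive (97.6037) scores below it. Once the sigmoid saturates, every score equals 1, so each of the 14 pairs is a tie worth 1/2:

```math
\begin{aligned}
\mathrm{AUROC} &= \frac{1}{n_+ n_-}\sum_{i:\,y_i=1}\ \sum_{j:\,y_j=0}\Big(\mathbb{1}[s_i > s_j] + \tfrac{1}{2}\,\mathbb{1}[s_i = s_j]\Big) \\
\text{raw logits:}\quad \mathrm{AUROC} &= \frac{13}{14 \cdot 1} \approx 0.9286 \\
\text{all } s_i = 1 \text{ after sigmoid:}\quad \mathrm{AUROC} &= \frac{14 \cdot \tfrac{1}{2}}{14 \cdot 1} = 0.5
\end{aligned}
```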

Environment

  • Windows 11 24H2
  • Python version 3.10.11
  • TorchMetrics version 1.4.3
  • PyTorch version 2.4.1+cu124

Additional context

This appears to be a floating-point precision problem with the sigmoid call at line 185 of _binary_precision_recall_curve_format in torchmetrics/src/torchmetrics/functional/classification/precision_recall_curve.py.

I extracted the necessary functions into a miniature binary_auroc function that uses exactly the same algorithm (it reproduces the problem on the example above; I did not test other examples):

import torch


def binary_auroc(
    preds: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    # The sigmoid step that the issue points to: logits -> probabilities.
    preds = preds.sigmoid()
    print(preds)

    # Sort scores (and their labels) in descending order of score.
    desc_score_indices = torch.argsort(preds, descending=True)
    preds = preds[desc_score_indices]
    target = target[desc_score_indices]

    # pred typically has many tied values. Here we extract
    # the indices associated with the distinct values. We also
    # concatenate a value for the end of the curve.
    distinct_value_indices = torch.nonzero(preds[1:] - preds[:-1], as_tuple=True)[0]
    threshold_idxs = torch.nn.functional.pad(distinct_value_indices, [0, 1], value=target.size(0) - 1)

    # Cumulative true and false positives at each distinct threshold.
    tps = torch.cumsum(target, dim=0)[threshold_idxs]
    fps = 1 + threshold_idxs - tps

    # Add an extra threshold position to make sure that the curve starts at (0, 0)
    tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
    fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
    tpr = tps / tps[-1]
    fpr = fps / fps[-1]
    # Area under the ROC curve via the trapezoidal rule.
    return torch.trapezoid(tpr, fpr, dim=-1)

Output:

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

tensor(0.5000)

preds = preds.sigmoid() maps every logit to exactly 1, so all predictions become tied even though the logits are distinct. To avoid being rounded to exactly 1, a positive logit must stay below about 36.74 in float64 (double) or 16.64 in float32.
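The cutoffs quoted above follow from the naive formula sigmoid(x) = 1 / (1 + exp(-x)): the denominator rounds to exactly 1.0 once exp(-x) is at most half the gap between 1.0 and the next representable number, i.e. 2**-24 in float32 and 2**-53 in float64. That gives x ≥ 24·ln 2 ≈ 16.64 and x ≥ 53·ln 2 ≈ 36.74. The sketch below emulates float32 with only the standard library by rounding each intermediate through struct. The helper names are hypothetical, and it assumes the naive formula; PyTorch's actual sigmoid kernel may evaluate the function differently, so its exact cutoff could differ slightly.

```python
import math
import struct

# Hypothetical helpers for illustration; they assume the naive
# 1 / (1 + exp(-x)) formula, not PyTorch's actual sigmoid kernel.


def to_float32(x: float) -> float:
    """Round a Python float (IEEE float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def sigmoid_f32(x: float) -> float:
    """Naive sigmoid with every intermediate rounded to float32."""
    e = to_float32(math.exp(-x))
    denom = to_float32(1.0 + e)  # becomes exactly 1.0 once e <= 2**-24
    return to_float32(1.0 / denom)


def sigmoid_f64(x: float) -> float:
    """Same formula in float64 (Python's native float)."""
    return 1.0 / (1.0 + math.exp(-x))  # 1.0 + e == 1.0 once e <= 2**-53


# Smallest x for which 1 + exp(-x) rounds to 1.
F32_LIMIT = 24 * math.log(2)
F64_LIMIT = 53 * math.log(2)

logits = [98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425, 99.2644, 99.5014,
          99.7280, 99.6595, 99.6931, 99.4667, 99.9623, 99.8949, 99.8768]

print(round(F32_LIMIT, 2), round(F64_LIMIT, 2))  # 16.64 36.74
print({sigmoid_f32(x) for x in logits})  # {1.0}: all 15 logits collapse to one value
print(sigmoid_f32(16.0) < 1.0, sigmoid_f32(17.0) == 1.0)  # True True
```

Every logit in the report is far above either cutoff, so the same collapse happens in float64 too.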

Suggested fix

It's probably a good idea to scale the raw logits before the sigmoid, something like this:

preds = preds / torch.max(torch.abs(preds))  # largest |logit| becomes 1; out-of-place, so the caller's tensor is not modified
preds = preds.sigmoid()

All functions that apply a sigmoid to raw logits would need the same fix.
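To check that the suggested scaling restores the expected value, the sketch below reruns the example without PyTorch. It assumes plain Python float64 instead of torch float32 (the saturation still occurs because every logit exceeds about 36.74), and it computes AUROC by counting correctly ordered (positive, negative) pairs with ties worth 1/2 instead of the trapezoidal ROC curve; the two agree. pairwise_auroc and scale_then_sigmoid are hypothetical names, not torchmetrics APIs.

```python
import math


def sigmoid(x: float) -> float:
    # Python floats are float64: this returns exactly 1.0 for x >= ~36.74.
    return 1.0 / (1.0 + math.exp(-x))


def pairwise_auroc(scores, labels):
    """Hypothetical helper: fraction of (positive, negative) pairs ranked
    correctly, with tied pairs counted as 1/2 (Mann-Whitney form)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    total = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                total += 1.0
            elif p == n:
                total += 0.5
    return total / (len(pos) * len(neg))


def scale_then_sigmoid(logits):
    """Hypothetical helper mirroring the suggested fix: divide by the largest
    |logit| (a positive constant, so the order is preserved), then sigmoid."""
    m = max(abs(x) for x in logits)
    if m == 0:
        return [sigmoid(x) for x in logits]  # all zeros: nothing to scale
    return [sigmoid(x / m) for x in logits]


logits = [98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425, 99.2644, 99.5014,
          99.7280, 99.6595, 99.6931, 99.4667, 99.9623, 99.8949, 99.8768]
labels = [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

raw = pairwise_auroc([sigmoid(x) for x in logits], labels)  # every score ties at 1.0
fixed = pairwise_auroc(scale_then_sigmoid(logits), labels)
exact = pairwise_auroc(logits, labels)  # AUROC only needs the ordering
print(raw, fixed, exact)  # 0.5 0.9285714285714286 0.9285714285714286
```

Because AUROC depends only on the ordering of the scores, this rescaling is harmless for it. For metrics that compare the resulting probabilities to a threshold, only a threshold of 0.5 is guaranteed to give the same decisions, since dividing by a positive constant preserves each logit's sign but not its probability.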
