Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions src/torchmetrics/utilities/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,34 @@ def normalize_logits_if_needed(tensor: Tensor, normalization: Optional[Literal["
# decrease sigmoid on cpu .
if tensor.device == torch.device("cpu"):
if not torch.all((tensor >= 0) * (tensor <= 1)):
tensor = tensor.sigmoid() if normalization == "sigmoid" else torch.softmax(tensor, dim=1)
if normalization == "sigmoid":
# Apply numerically stable sigmoid by subtracting max to prevent overflow
# For large positive logits (>16.7 for float32, >36.7 for float64), sigmoid(x) overflows to 1.0
# Only apply stabilization when min value is also large (indicating all values will overflow)
# This avoids the issue where subtracting max creates artificial ties for widely spread values
min_val = tensor.min()
max_val = tensor.max()
if min_val > 15: # All values are large enough to potentially overflow
tensor = (tensor - max_val).sigmoid()
else:
tensor = tensor.sigmoid()
else:
tensor = torch.softmax(tensor, dim=1)
return tensor

# decrease device-host sync on device .
condition = ((tensor < 0) | (tensor > 1)).any()
return torch.where(
condition,
torch.sigmoid(tensor) if normalization == "sigmoid" else torch.softmax(tensor, dim=1),
tensor,
)
if normalization == "sigmoid":
# Apply numerically stable sigmoid by subtracting max to prevent overflow
# Only stabilize when all values are large to avoid creating artificial ties
min_val = tensor.min()
max_val = tensor.max()
# Use stable sigmoid only when minimum value is also large (all values will overflow)
needs_stabilization = min_val > 15
if needs_stabilization:
tensor_stable = tensor - max_val
return torch.where(condition, tensor_stable.sigmoid(), tensor)
else:
return torch.where(condition, tensor.sigmoid(), tensor)
else:
return torch.where(condition, torch.softmax(tensor, dim=1), tensor)
35 changes: 35 additions & 0 deletions tests/unittests/classification/test_auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,41 @@ def test_binary_auroc_threshold_arg(self, inputs, threshold_fn):
assert torch.allclose(ap1, ap2)


def test_binary_auroc_large_logits():
    """Test that large logits don't cause numerical overflow in sigmoid.

    Regression test for issue where very large logits (>16.7 for float32) cause
    naive sigmoid to overflow to 1.0 for all values, losing ranking information
    needed for AUROC.
    """
    # Test case from the issue: all logits in range 97-100.
    # The single negative (label 0) is preds[0] = 98.0950.
    preds = torch.tensor([98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425, 99.2644,
                          99.5014, 99.7280, 99.6595, 99.6931, 99.4667, 99.9623, 99.8949, 99.8768])
    target = torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

    result = binary_auroc(preds, target, thresholds=None)

    # Expected AUROC is 13/14 ~= 0.9286 (as computed by sklearn): the single
    # negative sample (98.0950) outranks exactly one positive (97.6037), so
    # 13 of the 14 negative/positive pairs are correctly ordered. A naive
    # sigmoid would saturate all values to 1.0 and destroy this ranking.
    expected_sklearn = sk_roc_auc_score(target.numpy(), preds.numpy())
    assert torch.allclose(result, torch.tensor(expected_sklearn), atol=1e-4)

    # Test with even larger logits (well past the float32 sigmoid saturation point).
    preds_huge = torch.tensor([200.0, 201.0, 202.0, 203.0])
    target_huge = torch.tensor([0, 0, 1, 1])
    result_huge = binary_auroc(preds_huge, target_huge, thresholds=None)
    expected_huge = sk_roc_auc_score(target_huge.numpy(), preds_huge.numpy())
    assert torch.allclose(result_huge, torch.tensor(expected_huge), atol=1e-4)

    # Test with mixed large and normal logits: stabilization must not kick in
    # here (min value is small), yet ranking must still match sklearn.
    preds_mixed = torch.tensor([-5.0, 0.0, 5.0, 50.0, 100.0])
    target_mixed = torch.tensor([0, 0, 1, 1, 1])
    result_mixed = binary_auroc(preds_mixed, target_mixed, thresholds=None)
    expected_mixed = sk_roc_auc_score(target_mixed.numpy(), preds_mixed.numpy())
    assert torch.allclose(result_mixed, torch.tensor(expected_mixed), atol=1e-4)


def _reference_sklearn_auroc_multiclass(preds, target, average="macro", ignore_index=None):
preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1]))
target = target.numpy().flatten()
Expand Down
Loading