diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py
index 647cfb8371a..def75725b42 100644
--- a/src/torchmetrics/utilities/compute.py
+++ b/src/torchmetrics/utilities/compute.py
@@ -234,13 +234,34 @@ def normalize_logits_if_needed(tensor: Tensor, normalization: Optional[Literal["
     # decrease sigmoid on cpu .
     if tensor.device == torch.device("cpu"):
         if not torch.all((tensor >= 0) * (tensor <= 1)):
-            tensor = tensor.sigmoid() if normalization == "sigmoid" else torch.softmax(tensor, dim=1)
+            if normalization == "sigmoid":
+                # Apply a numerically stable sigmoid by subtracting the max to prevent saturation:
+                # for large positive logits (>16.7 for float32, >36.7 for float64), sigmoid(x) rounds to 1.0.
+                # Only apply the stabilization when the min value is also large (all values would saturate);
+                # this avoids creating artificial ties when widely spread values are shifted by the max.
+                min_val = tensor.min()
+                max_val = tensor.max()
+                if min_val > 15:  # all values are large enough to saturate sigmoid
+                    tensor = (tensor - max_val).sigmoid()
+                else:
+                    tensor = tensor.sigmoid()
+            else:
+                tensor = torch.softmax(tensor, dim=1)
         return tensor
 
     # decrease device-host sync on device .
     condition = ((tensor < 0) | (tensor > 1)).any()
-    return torch.where(
-        condition,
-        torch.sigmoid(tensor) if normalization == "sigmoid" else torch.softmax(tensor, dim=1),
-        tensor,
-    )
+    if normalization == "sigmoid":
+        # Apply a numerically stable sigmoid by subtracting the max to prevent saturation.
+        # Only stabilize when all values are large, to avoid creating artificial ties.
+        min_val = tensor.min()
+        max_val = tensor.max()
+        # Use the stable sigmoid only when the minimum value is also large (all values would saturate).
+        needs_stabilization = min_val > 15
+        if needs_stabilization:
+            tensor_stable = tensor - max_val
+            return torch.where(condition, tensor_stable.sigmoid(), tensor)
+        else:
+            return torch.where(condition, tensor.sigmoid(), tensor)
+    else:
+        return torch.where(condition, torch.softmax(tensor, dim=1), tensor)
diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py
index df12366b941..a30d281e989 100644
--- a/tests/unittests/classification/test_auroc.py
+++ b/tests/unittests/classification/test_auroc.py
@@ -140,6 +140,41 @@ def test_binary_auroc_threshold_arg(self, inputs, threshold_fn):
         assert torch.allclose(ap1, ap2)
 
 
+def test_binary_auroc_large_logits():
+    """Test that large logits do not lose ranking information to sigmoid saturation.
+
+    Regression test for the issue where very large logits (>16.7 for float32) cause
+    the naive sigmoid to saturate at 1.0 for all values, losing the ranking
+    information needed for AUROC.
+ """ + # Test case from the issue: all logits in range 97-100 + preds = torch.tensor([98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425, 99.2644, + 99.5014, 99.7280, 99.6595, 99.6931, 99.4667, 99.9623, 99.8949, 99.8768]) + target = torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + + result = binary_auroc(preds, target, thresholds=None) + + # Expected AUROC is 0.9286 (as computed by sklearn) + # The ranking is preserved: lowest value (97.6037) corresponds to label 0, + # all others are higher and correspond to label 1 + expected_sklearn = sk_roc_auc_score(target.numpy(), preds.numpy()) + assert torch.allclose(result, torch.tensor(expected_sklearn), atol=1e-4) + + # Test with even larger logits + preds_huge = torch.tensor([200.0, 201.0, 202.0, 203.0]) + target_huge = torch.tensor([0, 0, 1, 1]) + result_huge = binary_auroc(preds_huge, target_huge, thresholds=None) + expected_huge = sk_roc_auc_score(target_huge.numpy(), preds_huge.numpy()) + assert torch.allclose(result_huge, torch.tensor(expected_huge), atol=1e-4) + + # Test with mixed large and normal logits + preds_mixed = torch.tensor([-5.0, 0.0, 5.0, 50.0, 100.0]) + target_mixed = torch.tensor([0, 0, 1, 1, 1]) + result_mixed = binary_auroc(preds_mixed, target_mixed, thresholds=None) + expected_mixed = sk_roc_auc_score(target_mixed.numpy(), preds_mixed.numpy()) + assert torch.allclose(result_mixed, torch.tensor(expected_mixed), atol=1e-4) + + def _reference_sklearn_auroc_multiclass(preds, target, average="macro", ignore_index=None): preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) target = target.numpy().flatten()