diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ccf8cae705..3bd98ac9b76 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - + +- Fixed `GeneralizedDiceScore` to yield `NaN` if there are missing classes ([#2846](https://github.com/Lightning-AI/torchmetrics/issues/2846)) + --- ## [1.8.2] - 2025-09-03 diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index 1a110980a32..d8d6f17f8d4 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -90,7 +90,7 @@ def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class: if not per_class: numerator = torch.sum(numerator, 1) denominator = torch.sum(denominator, 1) - return _safe_divide(numerator, denominator) + return _safe_divide(numerator, denominator, "nan") def generalized_dice_score( diff --git a/src/torchmetrics/segmentation/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py index a047ecf2b18..6bd3c182410 100644 --- a/src/torchmetrics/segmentation/generalized_dice.py +++ b/src/torchmetrics/segmentation/generalized_dice.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch from torch import Tensor @@ -24,6 +24,7 @@ _generalized_dice_validate_args, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE @@ -100,15 +101,16 @@ class GeneralizedDiceScore(Metric): tensor(0.4992) >>> gds = GeneralizedDiceScore(num_classes=3, per_class=True) >>> gds(preds, target) - tensor([0.5001, 0.4993, 0.4982]) + tensor([0.5000, 0.4993, 0.4983]) >>> gds = GeneralizedDiceScore(num_classes=3, per_class=True, include_background=False) >>> gds(preds, target) - tensor([0.4993, 0.4982]) + tensor([0.4993, 0.4983]) """ - score: Tensor - samples: Tensor + class_present: Tensor + numerator: List[Tensor] + denominator: List[Tensor] full_state_update: bool = False is_differentiable: bool = False higher_is_better: bool = True @@ -133,20 +135,26 @@ def __init__( self.input_format = input_format num_classes = num_classes - 1 if not include_background else num_classes - self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum") - self.add_state("samples", default=torch.zeros(1), dist_reduce_fx="sum") + self.add_state("numerator", default=[], dist_reduce_fx="cat") + self.add_state("denominator", default=[], dist_reduce_fx="cat") def update(self, preds: Tensor, target: Tensor) -> None: """Update the state with new data.""" numerator, denominator = _generalized_dice_update( preds, target, self.num_classes, self.include_background, self.weight_type, self.input_format ) - self.score += _generalized_dice_compute(numerator, denominator, self.per_class).sum(dim=0) - self.samples += preds.shape[0] + self.numerator.append(numerator) + self.denominator.append(denominator) def compute(self) -> Tensor: """Compute the final generalized dice score.""" - return self.score / self.samples + numerator = dim_zero_cat(self.numerator) + denominator = dim_zero_cat(self.denominator) + if self.per_class: + numerator = torch.sum(numerator, 0, keepdim=True) + denominator = torch.sum(denominator, 0, keepdim=True) + score = _generalized_dice_compute(dim_zero_cat(numerator), dim_zero_cat(denominator), self.per_class) + return score.mean(dim=0) def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index fbecdbdb041..cc55bd58a65 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -82,6 +82,8 @@ def _reference_generalized_dice( class TestGeneralizedDiceScore(MetricTester): """Test class for `GeneralizedDiceScore` metric.""" + atol = 2e-3 + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) def test_generalized_dice_class(self, preds, target, input_format, include_background, ddp): """Test class implementation of metric.""" @@ -122,3 +124,52 @@ def test_generalized_dice_functional(self, preds, target, input_format, include_ "input_format": input_format, }, ) + + +@pytest.mark.parametrize("per_class", [True, False]) +@pytest.mark.parametrize("include_background", [True, False]) +def test_samples_with_missing_classes(per_class, include_background): + """Test GeneralizedDiceScore with missing classes in some samples.""" + target = torch.zeros((4, NUM_CLASSES, 128, 128), dtype=torch.int8) + preds = torch.zeros((4, NUM_CLASSES, 128, 128), dtype=torch.int8) + + target[0, 0, 0, 0] = 1 + preds[0, 0, 0, 0] = 1 + + target[2, 1, 0, 0] = 1 + preds[2, 1, 0, 0] = 1 + + metric = GeneralizedDiceScore(num_classes=NUM_CLASSES, per_class=per_class, include_background=include_background) + score = metric(preds, target) + + target_slice = target if include_background else target[:, 1:] + output_classes = NUM_CLASSES if include_background else NUM_CLASSES - 1 + + if per_class: + assert len(score) == output_classes + for c in range(output_classes): + assert score[c] == 1.0 if target_slice[:, c].sum() > 0 else torch.isnan(score[c]) + else: + assert score.isnan() + + +@pytest.mark.parametrize("per_class", [True, False]) +@pytest.mark.parametrize("include_background", [True, False]) +def test_generalized_dice_zero_denominator(per_class, include_background): + """Check that GeneralizedDiceScore returns NaN when the denominator is all zero (no class present).""" + target = torch.full((4, NUM_CLASSES, 128, 128), 0, dtype=torch.int8) + preds = torch.full((4, NUM_CLASSES, 128, 128), 0, dtype=torch.int8) + + metric = GeneralizedDiceScore(num_classes=NUM_CLASSES, per_class=per_class, include_background=include_background) + + score = metric(preds, target) + + if per_class and include_background: + assert len(score) == NUM_CLASSES + assert all(t.isnan() for t in score) + elif per_class and not include_background: + assert len(score) == NUM_CLASSES - 1 + assert all(t.isnan() for t in score) + else: + # Expect scalar NaN + assert score.isnan()