Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-



- Fixed `GeneralizedDiceScore` to yield `NaN` if there are missing classes ([#2846](https://github.com/Lightning-AI/torchmetrics/issues/2846))

---

## [1.8.2] - 2025-09-03
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class:
if not per_class:
numerator = torch.sum(numerator, 1)
denominator = torch.sum(denominator, 1)
return _safe_divide(numerator, denominator, "nan")
return _safe_divide(numerator, denominator)


Expand Down
28 changes: 19 additions & 9 deletions src/torchmetrics/segmentation/generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,16 @@ class GeneralizedDiceScore(Metric):
tensor(0.4992)
>>> gds = GeneralizedDiceScore(num_classes=3, per_class=True)
>>> gds(preds, target)
tensor([0.5001, 0.4993, 0.4982])
tensor([0.5000, 0.4993, 0.4983])
>>> gds = GeneralizedDiceScore(num_classes=3, per_class=True, include_background=False)
>>> gds(preds, target)
tensor([0.4993, 0.4982])
tensor([0.4993, 0.4983])

"""

score: Tensor
samples: Tensor
class_present: Tensor
numerator: Tensor
denominator: Tensor
full_state_update: bool = False
is_differentiable: bool = False
higher_is_better: bool = True
Expand All @@ -133,20 +134,29 @@ def __init__(
self.input_format = input_format

num_classes = num_classes - 1 if not include_background else num_classes
self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum")
self.add_state("samples", default=torch.zeros(1), dist_reduce_fx="sum")
self.add_state("class_present", default=torch.zeros(num_classes, dtype=torch.int), dist_reduce_fx="sum")
self.add_state("numerator", default=torch.zeros((0, num_classes)), dist_reduce_fx="cat")
self.add_state("denominator", default=torch.zeros((0, num_classes)), dist_reduce_fx="cat")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update the state with new data."""
numerator, denominator = _generalized_dice_update(
preds, target, self.num_classes, self.include_background, self.weight_type, self.input_format
)
self.score += _generalized_dice_compute(numerator, denominator, self.per_class).sum(dim=0)
self.samples += preds.shape[0]
self.numerator = torch.cat([self.numerator, numerator], dim=0)
self.denominator = torch.cat([self.denominator, denominator], dim=0)
if self.per_class:
class_mask = target.sum(dim=(0, *range(2, target.ndim))) > 0
self.class_present += class_mask[1:] if not self.include_background else class_mask
self.numerator = torch.sum(self.numerator, dim=0, keepdim=True)
self.denominator = torch.sum(self.denominator, dim=0, keepdim=True)

def compute(self) -> Tensor:
"""Compute the final generalized dice score."""
return self.score / self.samples
score = _generalized_dice_compute(self.numerator, self.denominator, self.per_class)
if not self.per_class:
return score.mean()
return torch.where(self.class_present > 0, score, torch.tensor(float("nan"))).squeeze()

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Expand Down
51 changes: 51 additions & 0 deletions tests/unittests/segmentation/test_generalized_dice_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def _reference_generalized_dice(
class TestGeneralizedDiceScore(MetricTester):
"""Test class for `GeneralizedDiceScore` metric."""

atol = 2e-2

@pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
def test_generalized_dice_class(self, preds, target, input_format, include_background, ddp):
"""Test class implementation of metric."""
Expand Down Expand Up @@ -122,3 +124,52 @@ def test_generalized_dice_functional(self, preds, target, input_format, include_
"input_format": input_format,
},
)


@pytest.mark.parametrize("per_class", [True, False])
@pytest.mark.parametrize("include_background", [True, False])
def test_samples_with_missing_classes(per_class, include_background):
"""Test GeneralizedDiceScore with missing classes in some samples."""
target = torch.zeros((4, NUM_CLASSES, 128, 128), dtype=torch.int8)
preds = torch.zeros((4, NUM_CLASSES, 128, 128), dtype=torch.int8)

target[0, 0, 0, 0] = 1
preds[0, 0, 0, 0] = 1

target[2, 1, 0, 0] = 1
preds[2, 1, 0, 0] = 1

metric = GeneralizedDiceScore(num_classes=NUM_CLASSES, per_class=per_class, include_background=include_background)
score = metric(preds, target)

target_slice = target if include_background else target[:, 1:]
output_classes = NUM_CLASSES if include_background else NUM_CLASSES - 1

if per_class:
assert len(score) == output_classes
for c in range(output_classes):
assert score[c] == 1.0 if target_slice[:, c].sum() > 0 else torch.isnan(score[c])
else:
assert score.isnan()


@pytest.mark.parametrize("per_class", [True, False])
@pytest.mark.parametrize("include_background", [True, False])
def test_generalized_dice_zero_denominator(per_class, include_background):
"""Check that GeneralizedDiceScore returns NaN when the denominator is all zero (no class present)."""
target = torch.full((4, NUM_CLASSES, 128, 128), 0, dtype=torch.int8)
preds = torch.full((4, NUM_CLASSES, 128, 128), 0, dtype=torch.int8)

metric = GeneralizedDiceScore(num_classes=NUM_CLASSES, per_class=per_class, include_background=include_background)

score = metric(preds, target)

if per_class and include_background:
assert len(score) == NUM_CLASSES
assert all(t.isnan() for t in score)
elif per_class and not include_background:
assert len(score) == NUM_CLASSES - 1
assert all(t.isnan() for t in score)
else:
# Expect scalar NaN
assert score.isnan()
Loading