Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-



- Fixed `GeneralizedDiceScore` to yield `NaN` if there are missing classes ([#2846](https://github.com/Lightning-AI/torchmetrics/issues/2846))

---

## [1.8.2] - 2025-09-03
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class:
if not per_class:
numerator = torch.sum(numerator, 1)
denominator = torch.sum(denominator, 1)
return _safe_divide(numerator, denominator, "nan")
return _safe_divide(numerator, denominator)


Expand Down
14 changes: 11 additions & 3 deletions src/torchmetrics/segmentation/generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,16 @@ class GeneralizedDiceScore(Metric):
tensor(0.4992)
>>> gds = GeneralizedDiceScore(num_classes=3, per_class=True)
>>> gds(preds, target)
tensor([0.5001, 0.4993, 0.4982])
tensor([5.0008, 4.9930, 4.9825])
>>> gds = GeneralizedDiceScore(num_classes=3, per_class=True, include_background=False)
>>> gds(preds, target)
tensor([0.4993, 0.4982])
tensor([4.9930, 4.9825])

"""

score: Tensor
samples: Tensor
class_present: Tensor
full_state_update: bool = False
is_differentiable: bool = False
higher_is_better: bool = True
Expand All @@ -135,6 +136,7 @@ def __init__(
num_classes = num_classes - 1 if not include_background else num_classes
self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum")
self.add_state("samples", default=torch.zeros(1), dist_reduce_fx="sum")
self.add_state("class_present", default=torch.zeros(num_classes, dtype=torch.int), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update the state with new data."""
Expand All @@ -144,9 +146,15 @@ def update(self, preds: Tensor, target: Tensor) -> None:
self.score += _generalized_dice_compute(numerator, denominator, self.per_class).sum(dim=0)
self.samples += preds.shape[0]

if self.per_class:
class_mask = target.sum(dim=(0, *range(2, target.ndim))) > 0
self.class_present += class_mask[1:] if not self.include_background else class_mask

def compute(self) -> Tensor:
    """Compute the final generalized dice score.

    Returns:
        With ``per_class=False``: the accumulated ``score`` divided by the accumulated ``samples`` count.
        With ``per_class=True``: the accumulated per-class ``score``, where every class whose
        ``class_present`` counter is still zero is reported as ``NaN``.

    """
    if not self.per_class:
        return self.score / self.samples
    # NOTE(review): unlike the aggregate branch above, this branch returns the running sum without
    # dividing by ``self.samples`` -- confirm whether the per-class result is meant to stay un-normalised.
    missing = torch.tensor(float("nan"))
    # Classes never flagged as present carry no information, so they are reported as NaN.
    return torch.where(self.class_present > 0, self.score, missing)

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Expand Down
49 changes: 49 additions & 0 deletions tests/unittests/segmentation/test_generalized_dice_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,52 @@ def test_generalized_dice_functional(self, preds, target, input_format, include_
"input_format": input_format,
},
)


@pytest.mark.parametrize("per_class", [True, False])
@pytest.mark.parametrize("include_background", [True, False])
def test_samples_with_missing_classes(per_class, include_background):
    """Check the metric when only a few samples contain any foreground class.

    Sample 0 holds one matching pixel of class 0 and sample 2 one matching pixel of class 1;
    every other sample and class is empty in both ``preds`` and ``target``.
    """
    shape = (4, NUM_CLASSES, 128, 128)
    preds = torch.zeros(shape, dtype=torch.int8)
    target = torch.zeros(shape, dtype=torch.int8)

    # Mark one perfectly predicted pixel for class 0 (sample 0) and class 1 (sample 2).
    for sample_idx, class_idx in ((0, 0), (2, 1)):
        target[sample_idx, class_idx, 0, 0] = 1
        preds[sample_idx, class_idx, 0, 0] = 1

    metric = GeneralizedDiceScore(num_classes=NUM_CLASSES, per_class=per_class, include_background=include_background)
    score = metric(preds, target)

    if not per_class:
        assert score.isnan()
        return

    # Dropping the background removes channel 0 from both the expected output and the target.
    if include_background:
        target_slice = target
        output_classes = NUM_CLASSES
    else:
        target_slice = target[:, 1:]
        output_classes = NUM_CLASSES - 1

    assert len(score) == output_classes
    for c in range(output_classes):
        if target_slice[:, c].sum() > 0:
            assert score[c] == 1.0
        else:
            assert torch.isnan(score[c])


@pytest.mark.parametrize("per_class", [True, False])
@pytest.mark.parametrize("include_background", [True, False])
def test_generalized_dice_zero_denominator(per_class, include_background):
    """Check that an all-zero ``preds``/``target`` pair (no class present anywhere) yields NaN."""
    shape = (4, NUM_CLASSES, 128, 128)
    preds = torch.zeros(shape, dtype=torch.int8)
    target = torch.zeros(shape, dtype=torch.int8)

    metric = GeneralizedDiceScore(num_classes=NUM_CLASSES, per_class=per_class, include_background=include_background)
    score = metric(preds, target)

    if per_class:
        # One entry per reported class; the background channel is dropped when excluded.
        expected_len = NUM_CLASSES if include_background else NUM_CLASSES - 1
        assert len(score) == expected_len
        assert all(entry.isnan() for entry in score)
    else:
        # Aggregated result: a single NaN value.
        assert score.isnan()
Loading