51 changes: 44 additions & 7 deletions monai/metrics/panoptic_quality.py
@@ -21,7 +21,7 @@

linear_sum_assignment, _ = optional_import("scipy.optimize", name="linear_sum_assignment")

__all__ = ["PanopticQualityMetric", "compute_panoptic_quality"]
__all__ = ["PanopticQualityMetric", "compute_panoptic_quality", "compute_mean_iou"]


class PanopticQualityMetric(CumulativeIterationMetric):
@@ -55,6 +55,8 @@ class PanopticQualityMetric(CumulativeIterationMetric):
If `match_iou_threshold` is set below 0.5, this function uses Munkres assignment to find the
maximal number of unique pairings.
smooth_numerator: a small constant added to the numerator to avoid zero.
return_confusion_matrix: if True, `aggregate` returns the raw confusion matrix values ``[tp, fp, fn, iou_sum]``
instead of the computed metric(s). Defaults to False.

"""

@@ -65,19 +67,22 @@ def __init__(
reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
match_iou_threshold: float = 0.5,
smooth_numerator: float = 1e-6,
return_confusion_matrix: bool = False,
) -> None:
super().__init__()
self.num_classes = num_classes
self.reduction = reduction
self.match_iou_threshold = match_iou_threshold
self.smooth_numerator = smooth_numerator
self.metric_name = ensure_tuple(metric_name)
self.return_confusion_matrix = return_confusion_matrix

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
"""
Args:
y_pred: Predictions. It must be in the form of B2HW and have integer type. The first channel and the
second channel represent the instance predictions and classification predictions respectively.
y_pred: Predictions. It must be in the form of B2HW (2D) or B2HWD (3D) and have integer type.
The first channel and the second channel represent the instance predictions and classification
predictions respectively.
y: ground truth. It must have the same shape as `y_pred` and have integer type. The first channel and the
second channel represent the instance labels and classification labels respectively.
Values in the second channel of `y_pred` and `y` should be in the range of 0 to `self.num_classes`,
@@ -86,7 +91,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
Raises:
ValueError: when `y_pred` and `y` have different shapes.
ValueError: when `y_pred` and `y` have != 2 channels.
ValueError: when `y_pred` and `y` have != 4 dimensions.
ValueError: when `y_pred` and `y` have != 4 or 5 dimensions.

"""
if y_pred.shape != y.shape:
@@ -98,8 +103,10 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
)

dims = y_pred.ndimension()
if dims != 4:
raise ValueError(f"y_pred should have 4 dimensions (batch, 2, h, w), got {dims}.")
if dims not in (4, 5):
raise ValueError(
f"y_pred should have 4 dimensions (batch, 2, h, w) or 5 dimensions (batch, 2, h, w, d), got {dims}."
)

batch_size = y_pred.shape[0]

@@ -131,13 +138,22 @@ def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Ten
available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.

Returns:
If `return_confusion_matrix` is True, returns the raw confusion matrix [tp, fp, fn, iou_sum].
Otherwise, returns the computed metric(s) based on `metric_name`.

"""
data = self.get_buffer()
if not isinstance(data, torch.Tensor):
raise ValueError("the data to aggregate must be PyTorch Tensor.")

# do metric reduction
f, _ = do_metric_reduction(data, reduction or self.reduction)

if self.return_confusion_matrix:
# Return raw confusion matrix values
return f

tp, fp, fn, iou_sum = f[..., 0], f[..., 1], f[..., 2], f[..., 3]
results = []
for metric_name in self.metric_name:
@@ -169,7 +185,7 @@ def compute_panoptic_quality(
calculate PQ, and returning them directly enables further calculation over all images.

Args:
pred: input data to compute, it must be in the form of HW and have integer type.
pred: input data to compute, it must be in the form of HW (2D) or HWD (3D) and have integer type.
gt: ground truth. It must have the same shape as `pred` and have integer type.
metric_name: output metric. The value can be "pq", "sq" or "rq".
remap: whether to remap `pred` and `gt` to ensure contiguous ordering of instance id.
@@ -294,3 +310,24 @@ def _check_panoptic_metric_name(metric_name: str) -> str:
if metric_name in ["recognition_quality", "rq"]:
return "rq"
raise ValueError(f"metric name: {metric_name} is wrong, please use 'pq', 'sq' or 'rq'.")


def compute_mean_iou(confusion_matrix: torch.Tensor, smooth_numerator: float = 1e-6) -> torch.Tensor:
"""Compute mean IoU from confusion matrix values.

Args:
confusion_matrix: tensor with shape (..., 4) where the last dimension contains
[tp, fp, fn, iou_sum] as returned by `compute_panoptic_quality` with `output_confusion_matrix=True`.
smooth_numerator: a small constant added to the numerator to avoid zero.

Returns:
Mean IoU computed as iou_sum / (tp + smooth_numerator).

"""
if confusion_matrix.shape[-1] != 4:
raise ValueError(
f"confusion_matrix should have shape (..., 4) with [tp, fp, fn, iou_sum], "
f"got shape {confusion_matrix.shape}."
)
tp, iou_sum = confusion_matrix[..., 0], confusion_matrix[..., 3]
return iou_sum / (tp + smooth_numerator)
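
Taken together, the changes above support a two-stage workflow: accumulate raw per-class counts during evaluation, then derive metrics over the whole run. A minimal usage sketch follows (not part of the diff; it assumes this branch's API, and the random tensors are only for shape illustration):

import torch

from monai.metrics import PanopticQualityMetric
from monai.metrics.panoptic_quality import compute_mean_iou

# Collect raw confusion-matrix values instead of computed metrics.
metric = PanopticQualityMetric(num_classes=3, return_confusion_matrix=True)

# B2HW integer input: channel 0 holds instance ids, channel 1 holds class labels in [0, num_classes].
y_pred = torch.randint(0, 4, (2, 2, 64, 64))
y = torch.randint(0, 4, (2, 2, 64, 64))

metric(y_pred, y)
cm = metric.aggregate()  # (num_classes, 4): [tp, fp, fn, iou_sum], batch-averaged under the default MEAN_BATCH reduction
mean_iou = compute_mean_iou(cm)  # per-class iou_sum / (tp + smooth_numerator)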
123 changes: 123 additions & 0 deletions tests/metrics/test_compute_panoptic_quality.py
@@ -88,6 +88,37 @@
[torch.as_tensor([[0.0, 1.0, 0.0], [0.6667, 0.0, 0.4]]), torch.as_tensor([[0.0, 0.5, 0.0], [0.3333, 0.0, 0.4]])],
]

# 3D test cases
sample_3d_pred = torch.as_tensor(
[
[
[[[2, 0], [1, 1]], [[0, 1], [2, 1]]], # instance channel
[[[0, 1], [3, 0]], [[1, 0], [1, 1]]], # class channel
]
],
device=_device,
)

sample_3d_gt = torch.as_tensor(
[
[
[[[2, 0], [0, 0]], [[2, 2], [2, 3]]], # instance channel
[[[3, 3], [3, 2]], [[2, 2], [3, 3]]], # class channel
]
],
device=_device,
)

# test 3D sample, num_classes = 3, match_iou_threshold = 0.5
TEST_3D_CASE_1 = [{"num_classes": 3, "match_iou_threshold": 0.5}, sample_3d_pred, sample_3d_gt]

# test confusion matrix return
TEST_CM_CASE_1 = [
{"num_classes": 3, "match_iou_threshold": 0.5, "return_confusion_matrix": True},
sample_3_pred,
sample_3_gt,
]


@SkipIfNoModule("scipy.optimize")
class TestPanopticQualityMetric(unittest.TestCase):
@@ -108,6 +139,98 @@ def test_value_class(self, input_params, y_pred, y_gt, expected_value):
else:
np.testing.assert_allclose(outputs.cpu().numpy(), np.asarray(expected_value), atol=1e-4)

def test_3d_support(self):
"""Test that 3D input is properly supported."""
input_params, y_pred, y_gt = TEST_3D_CASE_1
metric = PanopticQualityMetric(**input_params)
# Should not raise an error for 3D input
metric(y_pred, y_gt)
outputs = metric.aggregate()
# Check that output is a tensor
self.assertIsInstance(outputs, torch.Tensor)
# Check that output shape is correct (num_classes,)
self.assertEqual(outputs.shape, torch.Size([3]))

def test_confusion_matrix_return(self):
"""Test that confusion matrix can be returned instead of computed metrics."""
input_params, y_pred, y_gt = TEST_CM_CASE_1
metric = PanopticQualityMetric(**input_params)
metric(y_pred, y_gt)
outputs = metric.aggregate()
# Check that output is a tensor with shape (batch_size, num_classes, 4)
self.assertIsInstance(outputs, torch.Tensor)
self.assertEqual(outputs.shape[-1], 4)
# Verify that values correspond to [tp, fp, fn, iou_sum]
tp, fp, fn, iou_sum = outputs[..., 0], outputs[..., 1], outputs[..., 2], outputs[..., 3]
# tp, fp, fn are counts and should be non-negative
self.assertTrue(torch.all(tp >= 0))
self.assertTrue(torch.all(fp >= 0))
self.assertTrue(torch.all(fn >= 0))
# iou_sum should be non-negative float
self.assertTrue(torch.all(iou_sum >= 0))

def test_compute_mean_iou(self):
"""Test mean IoU computation from confusion matrix."""
from monai.metrics.panoptic_quality import compute_mean_iou

input_params, y_pred, y_gt = TEST_CM_CASE_1
metric = PanopticQualityMetric(**input_params)
metric(y_pred, y_gt)
confusion_matrix = metric.aggregate()
mean_iou = compute_mean_iou(confusion_matrix)
# Check shape is correct
self.assertEqual(mean_iou.shape, confusion_matrix.shape[:-1])
# Check values are non-negative
self.assertTrue(torch.all(mean_iou >= 0))

def test_metric_name_filtering(self):
"""Test that metric_name parameter properly filters output."""
# Test single metric "sq"
metric_sq = PanopticQualityMetric(num_classes=3, metric_name="sq", match_iou_threshold=0.5)
metric_sq(sample_3_pred, sample_3_gt)
result_sq = metric_sq.aggregate()
self.assertIsInstance(result_sq, torch.Tensor)
self.assertEqual(result_sq.shape, torch.Size([3]))

# Test single metric "rq"
metric_rq = PanopticQualityMetric(num_classes=3, metric_name="rq", match_iou_threshold=0.5)
metric_rq(sample_3_pred, sample_3_gt)
result_rq = metric_rq.aggregate()
self.assertIsInstance(result_rq, torch.Tensor)
self.assertEqual(result_rq.shape, torch.Size([3]))

# Results should be different for different metrics
self.assertFalse(torch.allclose(result_sq, result_rq, atol=1e-4))

def test_invalid_3d_shape(self):
"""Test that invalid 3D shapes are rejected."""
# Shape with 3 dimensions should fail
invalid_pred = torch.randint(0, 5, (2, 2, 10))
invalid_gt = torch.randint(0, 5, (2, 2, 10))
metric = PanopticQualityMetric(num_classes=3)
with self.assertRaises(ValueError):
metric(invalid_pred, invalid_gt)

# Shape with 6 dimensions should fail
invalid_pred = torch.randint(0, 5, (1, 2, 8, 8, 8, 8))
invalid_gt = torch.randint(0, 5, (1, 2, 8, 8, 8, 8))
with self.assertRaises(ValueError):
metric(invalid_pred, invalid_gt)

def test_compute_mean_iou_invalid_shape(self):
"""Test that compute_mean_iou raises ValueError for invalid shapes."""
from monai.metrics.panoptic_quality import compute_mean_iou

# Shape (..., 3) instead of (..., 4) should fail
invalid_confusion_matrix = torch.zeros(3, 3)
with self.assertRaises(ValueError):
compute_mean_iou(invalid_confusion_matrix)

# Shape (..., 5) should also fail
invalid_confusion_matrix = torch.zeros(2, 5)
with self.assertRaises(ValueError):
compute_mean_iou(invalid_confusion_matrix)


if __name__ == "__main__":
unittest.main()
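
For completeness, the new 3D path can also be exercised outside the test harness with a B2HWD tensor; a minimal sketch (shapes illustrative, assumes this branch is installed):

import torch

from monai.metrics import PanopticQualityMetric

metric = PanopticQualityMetric(num_classes=3, match_iou_threshold=0.5)
y_pred = torch.randint(0, 4, (1, 2, 8, 8, 8))  # (batch, 2, h, w, d)
y = torch.randint(0, 4, (1, 2, 8, 8, 8))
metric(y_pred, y)  # no longer raises: dims may be 4 or 5
print(metric.aggregate().shape)  # torch.Size([3]): one PQ value per class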