diff --git a/monai/metrics/panoptic_quality.py b/monai/metrics/panoptic_quality.py
index 7c9d59c264..03cef9d566 100644
--- a/monai/metrics/panoptic_quality.py
+++ b/monai/metrics/panoptic_quality.py
@@ -21,7 +21,7 @@
 
 linear_sum_assignment, _ = optional_import("scipy.optimize", name="linear_sum_assignment")
 
-__all__ = ["PanopticQualityMetric", "compute_panoptic_quality"]
+__all__ = ["PanopticQualityMetric", "compute_panoptic_quality", "compute_mean_iou"]
 
 
 class PanopticQualityMetric(CumulativeIterationMetric):
@@ -55,6 +55,8 @@ class PanopticQualityMetric(CumulativeIterationMetric):
             If set `match_iou_threshold` < 0.5, this function uses Munkres assignment to find the maximal
             amount of unique pairing.
         smooth_numerator: a small constant added to the numerator to avoid zero.
+        return_confusion_matrix: if True, `aggregate` returns the raw confusion matrix values
+            (tp, fp, fn, iou_sum) instead of the computed metrics. Defaults to False.
 
     """
 
@@ -65,6 +67,7 @@ def __init__(
         self,
         num_classes: int,
         metric_name: Sequence[str] | str = "pq",
         reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
         match_iou_threshold: float = 0.5,
         smooth_numerator: float = 1e-6,
+        return_confusion_matrix: bool = False,
     ) -> None:
         super().__init__()
         self.num_classes = num_classes
@@ -72,12 +75,14 @@
         self.match_iou_threshold = match_iou_threshold
         self.smooth_numerator = smooth_numerator
         self.metric_name = ensure_tuple(metric_name)
+        self.return_confusion_matrix = return_confusion_matrix
 
     def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
         """
         Args:
-            y_pred: Predictions. It must be in the form of B2HW and have integer type. The first channel and the
-                second channel represent the instance predictions and classification predictions respectively.
+            y_pred: Predictions. It must be in the form of B2HW (2D) or B2HWD (3D) and have integer type.
+                The first channel and the second channel represent the instance predictions and classification
+                predictions respectively.
             y: ground truth. It must have the same shape as `y_pred` and have integer type. The first channel and the
                 second channel represent the instance labels and classification labels respectively. Values in the
                 second channel of `y_pred` and `y` should be in the range of 0 to `self.num_classes`,
@@ -86,7 +91,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
         Raises:
             ValueError: when `y_pred` and `y` have different shapes.
             ValueError: when `y_pred` and `y` have != 2 channels.
-            ValueError: when `y_pred` and `y` have != 4 dimensions.
+            ValueError: when `y_pred` and `y` have != 4 or 5 dimensions.
 
         """
         if y_pred.shape != y.shape:
@@ -98,8 +103,10 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
             )
 
         dims = y_pred.ndimension()
-        if dims != 4:
-            raise ValueError(f"y_pred should have 4 dimensions (batch, 2, h, w), got {dims}.")
+        if dims not in (4, 5):
+            raise ValueError(
+                f"y_pred should have 4 dimensions (batch, 2, h, w) or 5 dimensions (batch, 2, h, w, d), got {dims}."
+            )
 
         batch_size = y_pred.shape[0]
 
@@ -131,6 +138,10 @@ def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Ten
                 available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
                 ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.
+
+        Returns:
+            If `return_confusion_matrix` is True, returns the raw confusion matrix [tp, fp, fn, iou_sum].
+            Otherwise, returns the computed metric(s) based on `metric_name`.
+ """ data = self.get_buffer() if not isinstance(data, torch.Tensor): @@ -138,6 +149,11 @@ def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Ten # do metric reduction f, _ = do_metric_reduction(data, reduction or self.reduction) + + if self.return_confusion_matrix: + # Return raw confusion matrix values + return f + tp, fp, fn, iou_sum = f[..., 0], f[..., 1], f[..., 2], f[..., 3] results = [] for metric_name in self.metric_name: @@ -169,7 +185,7 @@ def compute_panoptic_quality( calculate PQ, and returning them directly enables further calculation over all images. Args: - pred: input data to compute, it must be in the form of HW and have integer type. + pred: input data to compute, it must be in the form of HW (2D) or HWD (3D) and have integer type. gt: ground truth. It must have the same shape as `pred` and have integer type. metric_name: output metric. The value can be "pq", "sq" or "rq". remap: whether to remap `pred` and `gt` to ensure contiguous ordering of instance id. @@ -294,3 +310,24 @@ def _check_panoptic_metric_name(metric_name: str) -> str: if metric_name in ["recognition_quality", "rq"]: return "rq" raise ValueError(f"metric name: {metric_name} is wrong, please use 'pq', 'sq' or 'rq'.") + + +def compute_mean_iou(confusion_matrix: torch.Tensor, smooth_numerator: float = 1e-6) -> torch.Tensor: + """Compute mean IoU from confusion matrix values. + + Args: + confusion_matrix: tensor with shape (..., 4) where the last dimension contains + [tp, fp, fn, iou_sum] as returned by `compute_panoptic_quality` with `output_confusion_matrix=True`. + smooth_numerator: a small constant added to the numerator to avoid zero. + + Returns: + Mean IoU computed as iou_sum / (tp + smooth_numerator). + + """ + if confusion_matrix.shape[-1] != 4: + raise ValueError( + f"confusion_matrix should have shape (..., 4) with [tp, fp, fn, iou_sum], " + f"got shape {confusion_matrix.shape}." 
+        )
+    tp, iou_sum = confusion_matrix[..., 0], confusion_matrix[..., 3]
+    return iou_sum / (tp + smooth_numerator)
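Review note: a minimal usage sketch tying the two additions above together, `return_confusion_matrix` on the class and the new `compute_mean_iou` helper. The input tensors are hypothetical placeholders (random B2HWD volumes), not fixtures from this PR:

```python
import torch

from monai.metrics import PanopticQualityMetric
from monai.metrics.panoptic_quality import compute_mean_iou

# hypothetical inputs in B2HWD form: channel 0 holds instance ids, channel 1 holds class ids
y_pred = torch.randint(0, 3, (1, 2, 16, 16, 16))
y = torch.randint(0, 3, (1, 2, 16, 16, 16))

# accumulate raw [tp, fp, fn, iou_sum] per class instead of computed metrics
metric = PanopticQualityMetric(num_classes=3, return_confusion_matrix=True)
metric(y_pred, y)
cm = metric.aggregate()  # shape (num_classes, 4) under the default "mean_batch" reduction

mean_iou = compute_mean_iou(cm)  # iou_sum / (tp + 1e-6), one value per class
```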
diff --git a/tests/metrics/test_compute_panoptic_quality.py b/tests/metrics/test_compute_panoptic_quality.py
index 2c0946a822..4479108b05 100644
--- a/tests/metrics/test_compute_panoptic_quality.py
+++ b/tests/metrics/test_compute_panoptic_quality.py
@@ -88,6 +88,27 @@
     [torch.as_tensor([[0.0, 1.0, 0.0], [0.6667, 0.0, 0.4]]), torch.as_tensor([[0.0, 0.5, 0.0], [0.3333, 0.0, 0.4]])],
 ]
 
+# 3D test cases
+sample_3d_pred = torch.as_tensor(
+    [[[[[2, 0], [1, 1]], [[0, 1], [2, 1]]],  # instance channel
+      [[[0, 1], [3, 0]], [[1, 0], [1, 1]]]]],  # class channel
+    device=_device,
+)
+
+sample_3d_gt = torch.as_tensor(
+    [[[[[2, 0], [0, 0]], [[2, 2], [2, 3]]],  # instance channel
+      [[[3, 3], [3, 2]], [[2, 2], [3, 3]]]]],  # class channel
+    device=_device,
+)
+
+# test 3D sample, num_classes = 3, match_iou_threshold = 0.5
+TEST_3D_CASE_1 = [{"num_classes": 3, "match_iou_threshold": 0.5}, sample_3d_pred, sample_3d_gt]
+
+# test confusion matrix return
+TEST_CM_CASE_1 = [
+    {"num_classes": 3, "match_iou_threshold": 0.5, "return_confusion_matrix": True},
+    sample_3_pred,
+    sample_3_gt,
+]
+
 
 @SkipIfNoModule("scipy.optimize")
 class TestPanopticQualityMetric(unittest.TestCase):
@@ -108,6 +129,98 @@ def test_value_class(self, input_params, y_pred, y_gt, expected_value):
         else:
             np.testing.assert_allclose(outputs.cpu().numpy(), np.asarray(expected_value), atol=1e-4)
 
+    def test_3d_support(self):
+        """Test that 3D input is properly supported."""
+        input_params, y_pred, y_gt = TEST_3D_CASE_1
+        metric = PanopticQualityMetric(**input_params)
+        # should not raise an error for 3D input
+        metric(y_pred, y_gt)
+        outputs = metric.aggregate()
+        # check that the output is a tensor
+        self.assertIsInstance(outputs, torch.Tensor)
+        # check that the output shape is (num_classes,)
+        self.assertEqual(outputs.shape, torch.Size([3]))
+
+    def test_confusion_matrix_return(self):
+        """Test that the confusion matrix can be returned instead of computed metrics."""
+        input_params, y_pred, y_gt = TEST_CM_CASE_1
+        metric = PanopticQualityMetric(**input_params)
+        metric(y_pred, y_gt)
+        outputs = metric.aggregate()
+        # check that the output is a tensor whose last dimension holds [tp, fp, fn, iou_sum]
+        self.assertIsInstance(outputs, torch.Tensor)
+        self.assertEqual(outputs.shape[-1], 4)
+        tp, fp, fn, iou_sum = outputs[..., 0], outputs[..., 1], outputs[..., 2], outputs[..., 3]
+        # tp, fp, fn counts should be non-negative
+        self.assertTrue(torch.all(tp >= 0))
+        self.assertTrue(torch.all(fp >= 0))
+        self.assertTrue(torch.all(fn >= 0))
+        # iou_sum should be non-negative
+        self.assertTrue(torch.all(iou_sum >= 0))
+
+    def test_compute_mean_iou(self):
+        """Test mean IoU computation from the confusion matrix."""
+        from monai.metrics.panoptic_quality import compute_mean_iou
+
+        input_params, y_pred, y_gt = TEST_CM_CASE_1
+        metric = PanopticQualityMetric(**input_params)
+        metric(y_pred, y_gt)
+        confusion_matrix = metric.aggregate()
+        mean_iou = compute_mean_iou(confusion_matrix)
+        # the trailing [tp, fp, fn, iou_sum] dimension is reduced away
+        self.assertEqual(mean_iou.shape, confusion_matrix.shape[:-1])
+        # IoU values should be non-negative
+        self.assertTrue(torch.all(mean_iou >= 0))
+
+    def test_metric_name_filtering(self):
+        """Test that the metric_name parameter properly filters the output."""
+        # test single metric "sq"
+        metric_sq = PanopticQualityMetric(num_classes=3, metric_name="sq", match_iou_threshold=0.5)
+        metric_sq(sample_3_pred, sample_3_gt)
+        result_sq = metric_sq.aggregate()
+        self.assertIsInstance(result_sq, torch.Tensor)
+        self.assertEqual(result_sq.shape, torch.Size([3]))
+
+        # test single metric "rq"
+        metric_rq = PanopticQualityMetric(num_classes=3, metric_name="rq", match_iou_threshold=0.5)
+        metric_rq(sample_3_pred, sample_3_gt)
+        result_rq = metric_rq.aggregate()
+        self.assertIsInstance(result_rq, torch.Tensor)
+        self.assertEqual(result_rq.shape, torch.Size([3]))
+
+        # results should differ between the two metrics
+        self.assertFalse(torch.allclose(result_sq, result_rq, atol=1e-4))
+
+    def test_invalid_3d_shape(self):
+        """Test that inputs with an unsupported number of dimensions are rejected."""
+        # 3-dimensional input should fail
+        invalid_pred = torch.randint(0, 5, (2, 2, 10))
+        invalid_gt = torch.randint(0, 5, (2, 2, 10))
+        metric = PanopticQualityMetric(num_classes=3)
+        with self.assertRaises(ValueError):
+            metric(invalid_pred, invalid_gt)
+
+        # 6-dimensional input should also fail
+        invalid_pred = torch.randint(0, 5, (1, 2, 8, 8, 8, 8))
+        invalid_gt = torch.randint(0, 5, (1, 2, 8, 8, 8, 8))
+        with self.assertRaises(ValueError):
+            metric(invalid_pred, invalid_gt)
+
+    def test_compute_mean_iou_invalid_shape(self):
+        """Test that compute_mean_iou raises ValueError for invalid shapes."""
+        from monai.metrics.panoptic_quality import compute_mean_iou
+
+        # a trailing dimension of 3 instead of 4 should fail
+        invalid_confusion_matrix = torch.zeros(3, 3)
+        with self.assertRaises(ValueError):
+            compute_mean_iou(invalid_confusion_matrix)
+
+        # a trailing dimension of 5 should also fail
+        invalid_confusion_matrix = torch.zeros(2, 5)
+        with self.assertRaises(ValueError):
+            compute_mean_iou(invalid_confusion_matrix)
+
 
 if __name__ == "__main__":
     unittest.main()
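Review note: returning the raw counts matters because PQ, SQ and RQ are ratios, so they must be derived after summing tp/fp/fn/iou_sum over images, as the `compute_panoptic_quality` docstring already points out. A sketch of that relationship, assuming the standard panoptic quality definitions (Kirillov et al.); `derive_metrics` is a hypothetical helper for illustration, not part of this PR, and `eps` mirrors `smooth_numerator`:

```python
import torch

def derive_metrics(cm: torch.Tensor, eps: float = 1e-6) -> dict[str, torch.Tensor]:
    # cm's last dimension holds [tp, fp, fn, iou_sum] accumulated over images
    tp, fp, fn, iou_sum = cm[..., 0], cm[..., 1], cm[..., 2], cm[..., 3]
    sq = iou_sum / (tp + eps)  # segmentation quality: mean IoU of matched instances
    rq = tp / (tp + 0.5 * fp + 0.5 * fn + eps)  # recognition quality: detection F1
    return {"pq": sq * rq, "sq": sq, "rq": rq}  # PQ factorizes as SQ * RQ
```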