diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ccf8cae705..22c67b9ea69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Defaulting Dice score `average="macro"` ([#3042](https://github.com/Lightning-AI/torchmetrics/pull/3042)) +- Added `ignore_index` to Segmentation IoU metric ([#2747](https://github.com/Lightning-AI/torchmetrics/issues/2747)) + + ### Deprecated - diff --git a/src/torchmetrics/functional/segmentation/mean_iou.py b/src/torchmetrics/functional/segmentation/mean_iou.py index 747644c3a52..7f7cb3037fb 100644 --- a/src/torchmetrics/functional/segmentation/mean_iou.py +++ b/src/torchmetrics/functional/segmentation/mean_iou.py @@ -72,8 +72,13 @@ def _mean_iou_update( num_classes: Optional[int] = None, include_background: bool = False, input_format: Literal["one-hot", "index", "mixed"] = "one-hot", + ignore_index: Optional[int] = None, ) -> tuple[Tensor, Tensor]: """Update the intersection and union counts for the mean IoU computation.""" + if ignore_index is not None and input_format == "index": + idx = target == ignore_index + target, preds = target[~idx], preds[~idx] + preds, target = _mean_iou_reshape_args(preds, target, input_format) preds, target = _segmentation_inputs_format(preds, target, include_background, num_classes, input_format) @@ -102,6 +107,7 @@ def mean_iou( include_background: bool = True, per_class: bool = False, input_format: Literal["one-hot", "index", "mixed"] = "one-hot", + ignore_index: Optional[int] = None, ) -> Tensor: """Calculates the mean Intersection over Union (mIoU) for semantic segmentation. @@ -117,6 +123,8 @@ def mean_iou( input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors, ``"index"`` for index tensors or ``"mixed"`` for one one-hot encoded and one index tensor + ignore_index: Class index to ignore in the target. This class will be ignored + in both the intersection and union computation. Only used when ``input_format="index"``. Returns: The mean IoU score @@ -151,7 +159,7 @@ def mean_iou( """ _mean_iou_validate_args(num_classes, include_background, per_class, input_format) - intersection, union = _mean_iou_update(preds, target, num_classes, include_background, input_format) + intersection, union = _mean_iou_update(preds, target, num_classes, include_background, input_format, ignore_index) scores = _mean_iou_compute(intersection, union, zero_division="nan") valid_classes = union > 0 return scores.nan_to_num(-1.0) if per_class else scores.nansum(dim=-1) / valid_classes.sum(dim=-1) diff --git a/src/torchmetrics/segmentation/mean_iou.py b/src/torchmetrics/segmentation/mean_iou.py index 2cb06964ed7..1af14271a8f 100644 --- a/src/torchmetrics/segmentation/mean_iou.py +++ b/src/torchmetrics/segmentation/mean_iou.py @@ -61,6 +61,8 @@ class MeanIoU(Metric): input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors, ``"index"`` for index tensors or ``"mixed"`` for one one-hot encoded and one index tensor + ignore_index: Class index to ignore in the target. This class will be ignored + in both the intersection and union computation. Only used when ``input_format="index"`` kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Raises: @@ -110,6 +112,7 @@ def __init__( include_background: bool = True, per_class: bool = False, input_format: Literal["one-hot", "index", "mixed"] = "one-hot", + ignore_index: Optional[int] = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -119,6 +122,7 @@ def __init__( self.per_class = per_class self.input_format = input_format self._is_initialized = False + self.ignore_index = ignore_index if num_classes is not None: num_classes = num_classes - 1 if not include_background else num_classes self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum") @@ -168,7 +172,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: self._is_initialized = True intersection, union = _mean_iou_update( - preds, target, self.num_classes, self.include_background, self.input_format + preds, target, self.num_classes, self.include_background, self.input_format, self.ignore_index ) score = _mean_iou_compute(intersection, union, zero_division=0.0) # only update for classes that are present (i.e. union > 0) diff --git a/tests/unittests/segmentation/test_mean_iou.py b/tests/unittests/segmentation/test_mean_iou.py index 6270d8a5a0c..848b94a0864 100644 --- a/tests/unittests/segmentation/test_mean_iou.py +++ b/tests/unittests/segmentation/test_mean_iou.py @@ -22,7 +22,7 @@ from torchmetrics.functional.segmentation.mean_iou import mean_iou from torchmetrics.segmentation.mean_iou import MeanIoU from unittests import NUM_CLASSES -from unittests._helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.segmentation.inputs import ( _index_input_1, _mixed_input_1, @@ -41,27 +41,30 @@ def _reference_mean_iou( include_background: bool = True, per_class: bool = True, reduce: bool = True, + ignore_index: Optional[int] = None, ): """Calculate reference metric for `MeanIoU`.""" if input_format == "index": + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) elif input_format == "mixed": if preds.dim() == (target.dim() + 1): if torch.is_floating_point(preds): preds = preds.argmax(dim=1) - preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1) - target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) + preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) + target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) elif (preds.dim() + 1) == target.dim(): if torch.is_floating_point(target): target = target.argmax(dim=1) - target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) - preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1) + target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) + preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) val = compute_iou(preds, target, include_background=include_background) val[torch.isnan(val)] = 0.0 if reduce: return torch.mean(val, 0) if per_class else torch.mean(val) + return val @@ -83,11 +86,14 @@ def _reference_mean_iou( class TestMeanIoU(MetricTester): """Test class for `MeanIoU` metric.""" - atol = 1e-4 + atol = 1e-2 @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) @pytest.mark.parametrize("per_class", [True, False]) - def test_mean_iou_class(self, preds, target, input_format, num_classes, include_background, per_class, ddp): + @pytest.mark.parametrize("ignore_index", [None, 255]) + def test_mean_iou_class( + self, preds, target, input_format, num_classes, include_background, per_class, ddp, ignore_index + ): """Test class implementation of metric.""" if input_format in ["index", "mixed"] and num_classes is None: with pytest.raises( @@ -96,6 +102,9 @@ def test_mean_iou_class(self, preds, target, input_format, num_classes, include_ MeanIoU(num_classes=None, input_format="index") return + if input_format == "index" and ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( ddp=ddp, preds=preds, @@ -108,16 +117,19 @@ def test_mean_iou_class(self, preds, target, input_format, num_classes, include_ include_background=include_background, per_class=per_class, reduce=True, + ignore_index=ignore_index, ), metric_args={ "num_classes": num_classes, "include_background": include_background, "per_class": per_class, "input_format": input_format, + "ignore_index": ignore_index, }, ) - def test_mean_iou_functional(self, preds, target, input_format, num_classes, include_background): + @pytest.mark.parametrize("ignore_index", [None, 255]) + def test_mean_iou_functional(self, preds, target, input_format, num_classes, include_background, ignore_index): """Test functional implementation of metric.""" if input_format == "index" and num_classes is None: with pytest.raises( @@ -125,7 +137,8 @@ def test_mean_iou_functional(self, preds, target, input_format, num_classes, inc ): mean_iou(preds, target, num_classes=None, input_format="index") return - + if input_format == "index" and ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_functional_metric_test( preds=preds, target=target, @@ -136,12 +149,14 @@ def test_mean_iou_functional(self, preds, target, input_format, num_classes, inc num_classes=num_classes, include_background=include_background, reduce=False, + ignore_index=ignore_index, ), metric_args={ "num_classes": num_classes, "include_background": include_background, "per_class": True, "input_format": input_format, + "ignore_index": ignore_index, }, ) @@ -197,3 +212,27 @@ def test_mean_iou_perfect_prediction(): expected_ious = [1.0, 1.0, 1.0] for idx, (iou, iou_func) in enumerate(zip(miou_per_class, miou_func)): assert iou == iou_func == expected_ious[idx] + + +def test_mean_iou_ignore_index(): + """Test mean IoU with ignore_index.""" + metric = MeanIoU(num_classes=3, per_class=True, input_format="index", ignore_index=255) + target = torch.tensor([ + [0, 2, 255], + [1, 0, 255], + [2, 2, 255], + ]) + preds = torch.tensor([ + [0, 1, 1], + [1, 0, 0], + [0, 1, 2], + ]) + metric.update(preds, target) + miou_per_class = metric.compute() + miou_func = mean_iou(preds, target, num_classes=3, per_class=True, input_format="index", ignore_index=255).mean( + dim=0 + ) # reduce over batch dim + expected_ious = [0.6667, 0.3333, 0.0] + for idx, (iou, iou_func) in enumerate(zip(miou_per_class, miou_func)): + assert torch.allclose(iou, iou_func, atol=1e-4) + assert iou == pytest.approx(expected_ious[idx], rel=1e-3)