Skip to content

Commit e0b7359

Browse files
j-dsouzaBorda
andauthored
[metrics] IoU Metric (#2062)
* add iou function * update stat scores * add iou class * add iou tests * chlog * Apply suggestions from code review * tests * docs * Apply suggestions from code review * docs Co-authored-by: Jirka <[email protected]> Co-authored-by: Jirka Borovec <[email protected]>
1 parent 79e1426 commit e0b7359

File tree

8 files changed

+146
-8
lines changed

8 files changed

+146
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1313
- Added metrics
1414
* Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
1515
* Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
16-
* Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488))
16+
* Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488), [#2062](https://github.com/PyTorchLightning/pytorch-lightning/pull/2062))
1717
* docs for all Metrics ([#2184](https://github.com/PyTorchLightning/pytorch-lightning/pull/2184), [#2209](https://github.com/PyTorchLightning/pytorch-lightning/pull/2209))
1818
* Regression metrics ([#2221](https://github.com/PyTorchLightning/pytorch-lightning/pull/2221))
1919
- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723))

docs/source/metrics.rst

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ Example::
3131
to a few metrics. Please feel free to create an issue/PR if you have a proposed
3232
metric or have found a bug.
3333

34-
--------------
34+
---
3535

3636
Implement a metric
3737
------------------
@@ -75,7 +75,7 @@ Here's an example showing how to implement a NumpyMetric
7575
.. autoclass:: pytorch_lightning.metrics.metric.NumpyMetric
7676
:noindex:
7777

78-
--------------
78+
---
7979

8080
Class Metrics
8181
-------------
@@ -207,6 +207,12 @@ MulticlassPrecisionRecall
207207
.. autoclass:: pytorch_lightning.metrics.classification.MulticlassPrecisionRecall
208208
:noindex:
209209

210+
IoU
211+
^^^
212+
213+
.. autoclass:: pytorch_lightning.metrics.classification.IoU
214+
:noindex:
215+
210216
RMSE
211217
^^^^
212218

@@ -219,7 +225,7 @@ RMSLE
219225
.. autoclass:: pytorch_lightning.metrics.regression.RMSE
220226
:noindex:
221227

222-
--------------
228+
---
223229

224230
Functional Metrics
225231
------------------
@@ -346,16 +352,23 @@ stat_scores (F)
346352
.. autofunction:: pytorch_lightning.metrics.functional.stat_scores
347353
:noindex:
348354

355+
iou (F)
356+
^^^^^^^
357+
358+
.. autofunction:: pytorch_lightning.metrics.functional.iou
359+
:noindex:
360+
349361
stat_scores_multiple_classes (F)
350362
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
351363

352364
.. autofunction:: pytorch_lightning.metrics.functional.stat_scores_multiple_classes
353365
:noindex:
354366

355-
----------------
367+
---
356368

357369
Metric pre-processing
358370
---------------------
371+
359372
Metric
360373

361374
to_categorical (F)
@@ -370,7 +383,7 @@ to_onehot (F)
370383
.. autofunction:: pytorch_lightning.metrics.functional.to_onehot
371384
:noindex:
372385

373-
----------------
386+
---
374387

375388
Sklearn interface
376389
-----------------

pytorch_lightning/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
MulticlassROC,
2121
Precision,
2222
PrecisionRecall,
23+
IoU,
2324
)
2425
from pytorch_lightning.metrics.sklearns import (
2526
AUC,
@@ -43,6 +44,7 @@
4344
'PrecisionRecallCurve',
4445
'ROC',
4546
'Recall',
47+
'IoU',
4648
]
4749
__regression_metrics = [
4850
'MSE',

pytorch_lightning/metrics/classification.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
roc,
1616
multiclass_roc,
1717
multiclass_precision_recall_curve,
18-
dice_score
18+
dice_score,
19+
iou,
1920
)
2021
from pytorch_lightning.metrics.metric import TensorMetric, TensorCollectionMetric
2122

@@ -770,3 +771,48 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
770771
nan_score=self.nan_score,
771772
no_fg_score=self.no_fg_score,
772773
reduction=self.reduction)
774+
775+
776+
class IoU(TensorMetric):
777+
"""
778+
Computes the intersection over union.
779+
780+
Example:
781+
782+
>>> pred = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0],
783+
... [0, 0, 1, 1, 1, 0, 0, 0],
784+
... [0, 0, 0, 0, 0, 0, 0, 0]])
785+
>>> target = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0],
786+
... [0, 0, 0, 1, 1, 1, 0, 0],
787+
... [0, 0, 0, 0, 0, 0, 0, 0]])
788+
>>> metric = IoU()
789+
>>> metric(pred, target)
790+
tensor(0.7045)
791+
792+
"""
793+
def __init__(self,
794+
remove_bg: bool = False,
795+
reduction: str = 'elementwise_mean'):
796+
"""
797+
Args:
798+
remove_bg: Flag to state whether a background class has been included
799+
within input parameters. If true, will remove background class. If
800+
false, return IoU over all classes.
801+
Assumes that background is '0' class in input tensor
802+
reduction: a method for reducing IoU over labels (default: takes the mean)
803+
Available reduction methods:
804+
805+
- elementwise_mean: takes the mean
806+
- none: pass array
807+
- sum: add elements
808+
"""
809+
super().__init__(name='iou')
810+
self.remove_bg = remove_bg
811+
self.reduction = reduction
812+
813+
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor,
814+
sample_weight: Optional[torch.Tensor] = None):
815+
"""
816+
Actual metric calculation.
817+
"""
818+
return iou(y_pred, y_true, remove_bg=self.remove_bg, reduction=self.reduction)

pytorch_lightning/metrics/functional/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,6 @@
1717
stat_scores,
1818
stat_scores_multiple_classes,
1919
to_categorical,
20-
to_onehot
20+
to_onehot,
21+
iou,
2122
)

pytorch_lightning/metrics/functional/classification.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,3 +901,49 @@ def dice_score(
901901

902902
scores[i - bg] += score_cls
903903
return reduce(scores, reduction=reduction)
904+
905+
906+
def iou(pred: torch.Tensor, target: torch.Tensor,
907+
num_classes: Optional[int] = None, remove_bg: bool = False,
908+
reduction: str = 'elementwise_mean'):
909+
"""
910+
Intersection over union, or Jaccard index calculation.
911+
912+
Args:
913+
pred: Tensor containing predictions
914+
915+
target: Tensor containing targets
916+
917+
num_classes: Optionally specify the number of classes
918+
919+
remove_bg: Flag to state whether a background class has been included
920+
within input parameters. If true, will remove background class. If
921+
false, return IoU over all classes.
922+
Assumes that background is '0' class in input tensor
923+
924+
reduction: a method for reducing IoU over labels (default: takes the mean)
925+
Available reduction methods:
926+
- elementwise_mean: takes the mean
927+
- none: pass array
928+
- sum: add elements
929+
930+
Returns:
931+
IoU score : Tensor containing single value if reduction is
932+
'elementwise_mean', or number of classes if reduction is 'none'
933+
934+
Example:
935+
936+
>>> target = torch.randint(0, 1, (10, 25, 25))
937+
>>> pred = torch.tensor(target)
938+
>>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
939+
>>> iou(pred, target)
940+
tensor(0.4914)
941+
942+
"""
943+
tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes)
944+
if remove_bg:
945+
tps = tps[1:]
946+
fps = fps[1:]
947+
fns = fns[1:]
948+
iou = tps / (fps + fns + tps)
949+
return reduce(iou, reduction=reduction)

tests/metrics/functional/test_classification.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
precision_recall_curve,
3232
roc,
3333
auc,
34+
iou,
3435
)
3536

3637

@@ -366,5 +367,22 @@ def test_dice_score(pred, target, expected):
366367
assert score == expected
367368

368369

370+
@pytest.mark.parametrize(['half_ones', 'reduction', 'remove_bg', 'expected'], [
371+
pytest.param(False, 'none', False, torch.Tensor([1, 1, 1])),
372+
pytest.param(False, 'elementwise_mean', False, torch.Tensor([1])),
373+
pytest.param(False, 'none', True, torch.Tensor([1, 1])),
374+
pytest.param(True, 'none', False, torch.Tensor([0.5, 0.5, 0.5])),
375+
pytest.param(True, 'elementwise_mean', False, torch.Tensor([0.5])),
376+
pytest.param(True, 'none', True, torch.Tensor([0.5, 0.5])),
377+
])
378+
def test_iou(half_ones, reduction, remove_bg, expected):
379+
pred = (torch.arange(120) % 3).view(-1, 1)
380+
target = (torch.arange(120) % 3).view(-1, 1)
381+
if half_ones:
382+
pred[:60] = 1
383+
iou_val = iou(pred, target, remove_bg=remove_bg, reduction=reduction)
384+
assert torch.allclose(iou_val, expected, atol=1e-9)
385+
386+
369387
# example data taken from
370388
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py

tests/metrics/test_classification.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
MulticlassROC,
2020
MulticlassPrecisionRecall,
2121
DiceCoefficient,
22+
IoU,
2223
)
2324

2425

@@ -205,3 +206,14 @@ def test_dice_coefficient(include_background):
205206
dice = dice_coeff(torch.randint(0, 1, (10, 25, 25)),
206207
torch.randint(0, 1, (10, 25, 25)))
207208
assert isinstance(dice, torch.Tensor)
209+
210+
211+
@pytest.mark.parametrize('remove_bg', [True, False])
212+
def test_iou(remove_bg):
213+
iou = IoU(remove_bg=remove_bg)
214+
assert iou.name == 'iou'
215+
216+
score = iou(torch.randint(0, 1, (10, 25, 25)),
217+
torch.randint(0, 1, (10, 25, 25)))
218+
219+
assert isinstance(score, torch.Tensor)

0 commit comments

Comments
 (0)