Skip to content

Commit 458bbad

Browse files
Borda and williamFalcon authored
Avoid zeros in dice and iou (#2567)
* nones * fix * fix * test * test * test * fix * eps * tpu * eps * type * test tpu * Update __init__.py Co-authored-by: William Falcon <[email protected]>
1 parent f35337a commit 458bbad

File tree

5 files changed

+64
-100
lines changed

5 files changed

+64
-100
lines changed

pytorch_lightning/metrics/functional/classification.py

Lines changed: 10 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from collections import Sequence
23
from functools import wraps
34
from typing import Optional, Tuple, Callable
@@ -6,7 +7,7 @@
67
from torch.nn import functional as F
78

89
from pytorch_lightning.metrics.functional.reduction import reduce
9-
from pytorch_lightning.utilities import rank_zero_warn
10+
from pytorch_lightning.utilities import rank_zero_warn, FLOAT16_EPSILON
1011

1112

1213
def to_onehot(
@@ -893,8 +894,8 @@ def dice_score(
893894
... [0.05, 0.05, 0.85, 0.05],
894895
... [0.05, 0.05, 0.05, 0.85]])
895896
>>> target = torch.tensor([0, 1, 3, 2])
896-
>>> average_precision(pred, target)
897-
tensor(0.2500)
897+
>>> dice_score(pred, target)
898+
tensor(0.3333)
898899
899900
"""
900901
num_classes = pred.shape[1]
@@ -907,14 +908,9 @@ def dice_score(
907908
continue
908909

909910
tp, fp, tn, fn, sup = stat_scores(pred=pred, target=target, class_index=i)
910-
911-
denom = (2 * tp + fp + fn + 1e-15).to(torch.float)
912-
913-
if torch.isclose(denom, torch.zeros_like(denom)).any():
914-
# nan result
915-
score_cls = nan_score
916-
else:
917-
score_cls = (2 * tp).to(torch.float) / denom
911+
denom = (2 * tp + fp + fn).to(torch.float)
912+
# nan result
913+
score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else nan_score
918914

919915
scores[i - bg] += score_cls
920916
return reduce(scores, reduction=reduction)
@@ -963,5 +959,7 @@ def iou(
963959
tps = tps[1:]
964960
fps = fps[1:]
965961
fns = fns[1:]
966-
iou = tps / (fps + fns + tps + 1e-15)
962+
denom = fps + fns + tps
963+
denom[denom == 0] = torch.tensor(FLOAT16_EPSILON).type_as(denom)
964+
iou = tps / denom
967965
return reduce(iou, reduction=reduction)

pytorch_lightning/utilities/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
"""General utilities"""
22

3+
import numpy
34
import torch
45

56
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
@@ -14,3 +15,7 @@
1415
APEX_AVAILABLE = True
1516

1617
NATIVE_AMP_AVALAIBLE = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
18+
19+
FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
20+
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
21+
FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps

tests/metrics/functional/test_classification.py

Lines changed: 19 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -47,57 +47,23 @@ def test_against_sklearn(sklearn_metric, torch_metric):
4747
"""Compare PL metrics to sklearn version."""
4848
device = 'cuda' if torch.cuda.is_available() else 'cpu'
4949

50-
pred = torch.randint(10, (500,), device=device)
51-
target = torch.randint(10, (500,), device=device)
50+
# iterate over different label counts in predictions and target
51+
for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]:
52+
pred = torch.randint(n_cls_pred, (300,), device=device)
53+
target = torch.randint(n_cls_target, (300,), device=device)
5254

53-
assert torch.allclose(
54-
torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
55-
pred.cpu().detach().numpy()), dtype=torch.float, device=device),
56-
torch_metric(pred, target))
57-
58-
pred = torch.randint(10, (200,), device=device)
59-
target = torch.randint(5, (200,), device=device)
60-
61-
assert torch.allclose(
62-
torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
63-
pred.cpu().detach().numpy()), dtype=torch.float, device=device),
64-
torch_metric(pred, target))
65-
66-
pred = torch.randint(5, (200,), device=device)
67-
target = torch.randint(10, (200,), device=device)
68-
69-
assert torch.allclose(
70-
torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
71-
pred.cpu().detach().numpy()), dtype=torch.float, device=device),
72-
torch_metric(pred, target))
55+
sk_score = sklearn_metric(target.cpu().detach().numpy(),
56+
pred.cpu().detach().numpy())
57+
sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
58+
pl_score = torch_metric(pred, target)
59+
assert torch.allclose(sk_score, pl_score)
7360

7461

7562
def test_onehot():
7663
test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
77-
expected = torch.tensor([
78-
[
79-
[1, 0, 0, 0, 0],
80-
[0, 1, 0, 0, 0],
81-
[0, 0, 1, 0, 0],
82-
[0, 0, 0, 1, 0],
83-
[0, 0, 0, 0, 1],
84-
[0, 0, 0, 0, 0],
85-
[0, 0, 0, 0, 0],
86-
[0, 0, 0, 0, 0],
87-
[0, 0, 0, 0, 0],
88-
[0, 0, 0, 0, 0]
89-
], [
90-
[0, 0, 0, 0, 0],
91-
[0, 0, 0, 0, 0],
92-
[0, 0, 0, 0, 0],
93-
[0, 0, 0, 0, 0],
94-
[0, 0, 0, 0, 0],
95-
[1, 0, 0, 0, 0],
96-
[0, 1, 0, 0, 0],
97-
[0, 0, 1, 0, 0],
98-
[0, 0, 0, 1, 0],
99-
[0, 0, 0, 0, 1]
100-
]
64+
expected = torch.stack([
65+
torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]),
66+
torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)])
10167
])
10268

10369
assert test_tensor.shape == (2, 5)
@@ -116,30 +82,9 @@ def test_onehot():
11682

11783

11884
def test_to_categorical():
119-
test_tensor = torch.tensor([
120-
[
121-
[1, 0, 0, 0, 0],
122-
[0, 1, 0, 0, 0],
123-
[0, 0, 1, 0, 0],
124-
[0, 0, 0, 1, 0],
125-
[0, 0, 0, 0, 1],
126-
[0, 0, 0, 0, 0],
127-
[0, 0, 0, 0, 0],
128-
[0, 0, 0, 0, 0],
129-
[0, 0, 0, 0, 0],
130-
[0, 0, 0, 0, 0]
131-
], [
132-
[0, 0, 0, 0, 0],
133-
[0, 0, 0, 0, 0],
134-
[0, 0, 0, 0, 0],
135-
[0, 0, 0, 0, 0],
136-
[0, 0, 0, 0, 0],
137-
[1, 0, 0, 0, 0],
138-
[0, 1, 0, 0, 0],
139-
[0, 0, 1, 0, 0],
140-
[0, 0, 0, 1, 0],
141-
[0, 0, 0, 0, 1]
142-
]
85+
test_tensor = torch.stack([
86+
torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]),
87+
torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)])
14388
]).to(torch.float)
14489

14590
expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
@@ -260,7 +205,9 @@ def test_fbeta_score(pred, target, beta, exp_score):
260205

261206

262207
@pytest.mark.parametrize(['pred', 'target', 'exp_score'], [
208+
pytest.param([0., 0., 0., 0.], [1., 1., 1., 1.], [0.0, 0.0]),
263209
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], [0.5, 0.5]),
210+
pytest.param([1., 0., 1., 0.], [1., 0., 1., 0.], [1.0, 1.0]),
264211
])
265212
def test_f1_score(pred, target, exp_score):
266213
score = f1_score(torch.tensor(pred), torch.tensor(target), reduction='none')
@@ -324,7 +271,7 @@ def test_roc_curve(pred, target, expected_tpr, expected_fpr):
324271

325272

326273
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
327-
pytest.param([0, 0, 1, 1], [0, 0, 1, 1], 1.),
274+
pytest.param([0, 1, 0, 1], [0, 1, 0, 1], 1.),
328275
pytest.param([1, 1, 0, 0], [0, 0, 1, 1], 0.),
329276
pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.5),
330277
pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 1.),
@@ -355,7 +302,7 @@ def test_auc(x, y, expected):
355302
# The precision is then the fraction of positive whatever the recall
356303
# is, as there is only one threshold:
357304
pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25),
358-
# With treshold .8 : 1 TP and 2 TN and one FN
305+
# With threshold 0.8 : 1 TP and 2 TN and one FN
359306
pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75),
360307
])
361308
def test_average_precision(scores, target, expected_score):

tests/metrics/functional/test_regression.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -14,43 +14,47 @@
1414

1515
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
1616
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.25),
17-
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 3.0)
17+
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 3.0),
1818
])
1919
def test_mse(pred, target, expected):
2020
score = mse(torch.tensor(pred), torch.tensor(target))
2121
assert score.item() == expected
2222

2323

2424
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
25+
pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0),
2526
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.5),
26-
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 1.7321)
27+
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 1.7321),
2728
])
2829
def test_rmse(pred, target, expected):
2930
score = rmse(torch.tensor(pred), torch.tensor(target))
3031
assert torch.allclose(score, torch.tensor(expected), atol=1e-3)
3132

3233

3334
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
35+
pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0),
3436
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.25),
35-
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 1.5)
37+
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 1.5),
3638
])
3739
def test_mae(pred, target, expected):
3840
score = mae(torch.tensor(pred), torch.tensor(target))
3941
assert score.item() == expected
4042

4143

4244
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
45+
pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0),
4346
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.0207),
44-
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 0.2841)
47+
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 0.2841),
4548
])
4649
def test_rmsle(pred, target, expected):
4750
score = rmsle(torch.tensor(pred), torch.tensor(target))
4851
assert torch.allclose(score, torch.tensor(expected), atol=1e-3)
4952

5053

5154
@pytest.mark.parametrize(['pred', 'target'], [
55+
pytest.param([0., 1., 2., 3.], [0., 1., 2., 3.]),
5256
pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.]),
53-
pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.])
57+
pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.]),
5458
])
5559
def test_psnr_with_skimage(pred, target):
5660
score = psnr(pred=torch.tensor(pred),
@@ -61,7 +65,7 @@ def test_psnr_with_skimage(pred, target):
6165

6266
@pytest.mark.parametrize(['pred', 'target'], [
6367
pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.]),
64-
pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.])
68+
pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.]),
6569
])
6670
def test_psnr_base_e_wider_range(pred, target):
6771
score = psnr(pred=torch.tensor(pred),

tests/metrics/test_sklearn.py

Lines changed: 20 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -43,38 +43,48 @@ def new_func(*args, **kwargs):
4343

4444
@pytest.mark.parametrize(['metric_class', 'sklearn_func', 'inputs'], [
4545
pytest.param(Accuracy(), sk_accuracy,
46-
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
46+
{'y_pred': torch.randint(10, size=(128,)),
47+
'y_true': torch.randint(10, size=(128,))},
4748
id='Accuracy'),
4849
pytest.param(AUC(), sk_auc,
4950
{'x': torch.arange(10, dtype=torch.float) / 10,
5051
'y': torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.3, 0.5, 0.6, 0.7])},
5152
id='AUC'),
5253
pytest.param(AveragePrecision(), sk_average_precision,
53-
{'y_score': torch.randint(2, size=(128,)), 'y_true': torch.randint(2, size=(128,))},
54+
{'y_score': torch.randint(2, size=(128,)),
55+
'y_true': torch.randint(2, size=(128,))},
5456
id='AveragePrecision'),
5557
pytest.param(ConfusionMatrix(), sk_confusion_matrix,
56-
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
58+
{'y_pred': torch.randint(10, size=(128,)),
59+
'y_true': torch.randint(10, size=(128,))},
5760
id='ConfusionMatrix'),
5861
pytest.param(F1(average='macro'), partial(sk_f1_score, average='macro'),
59-
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
62+
{'y_pred': torch.randint(10, size=(128,)),
63+
'y_true': torch.randint(10, size=(128,))},
6064
id='F1'),
6165
pytest.param(FBeta(beta=0.5, average='macro'), partial(sk_fbeta_score, beta=0.5, average='macro'),
62-
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
66+
{'y_pred': torch.randint(10, size=(128,)),
67+
'y_true': torch.randint(10, size=(128,))},
6368
id='FBeta'),
6469
pytest.param(Precision(average='macro'), partial(sk_precision, average='macro'),
65-
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
70+
{'y_pred': torch.randint(10, size=(128,)),
71+
'y_true': torch.randint(10, size=(128,))},
6672
id='Precision'),
6773
pytest.param(Recall(average='macro'), partial(sk_recall, average='macro'),
68-
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
74+
{'y_pred': torch.randint(10, size=(128,)),
75+
'y_true': torch.randint(10, size=(128,))},
6976
id='Recall'),
7077
pytest.param(PrecisionRecallCurve(), _xy_only(sk_precision_recall_curve),
71-
{'probas_pred': torch.rand(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
78+
{'probas_pred': torch.rand(size=(128,)),
79+
'y_true': torch.randint(2, size=(128,))},
7280
id='PrecisionRecallCurve'),
7381
pytest.param(ROC(), _xy_only(sk_roc_curve),
74-
{'y_score': torch.rand(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
82+
{'y_score': torch.rand(size=(128,)),
83+
'y_true': torch.randint(2, size=(128,))},
7584
id='ROC'),
7685
pytest.param(AUROC(), sk_roc_auc_score,
77-
{'y_score': torch.rand(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
86+
{'y_score': torch.rand(size=(128,)),
87+
'y_true': torch.randint(2, size=(128,))},
7888
id='AUROC'),
7989
])
8090
def test_sklearn_metric(metric_class, sklearn_func, inputs):

0 commit comments

Comments
 (0)