
Commit 278b9a9

SkafteNicki authored and Borda committed

[Metrics] Unification of FBeta (#4656)

* implementation
* init files
* more stable reduction
* add tests
* docs
* remove old implementation
* pep8
* changelog
* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: Teddy Koker <[email protected]>

(cherry picked from commit 6831ba9)

1 parent c8afcab · commit 278b9a9

File tree

13 files changed (+421 additions, -216 deletions)

CHANGELOG.md

Lines changed: 6 additions & 0 deletions

@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added casting to python types for numpy scalars when logging hparams ([#4647](https://github.com/PyTorchLightning/pytorch-lightning/pull/4647))


+- Added `F1` class metric ([#4656](https://github.com/PyTorchLightning/pytorch-lightning/pull/4656))
+
+
 ### Changed

 - Consistently use `step=trainer.global_step` in `LearningRateMonitor` independently of `logging_interval` ([#4376](https://github.com/PyTorchLightning/pytorch-lightning/pull/4376))
@@ -20,6 +23,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Metric states are no longer as default added to `state_dict` ([#4685](https://github.com/PyTorchLightning/pytorch-lightning/pull/))


+- Renamed class metric `Fbeta` >> `FBeta` ([#4656](https://github.com/PyTorchLightning/pytorch-lightning/pull/4656))
+
+
 ### Deprecated
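Taken together, the two changelog entries are a rename plus one new class. A minimal before/after sketch of the import surface, based on the docstring examples further down in this diff:

```python
import torch

# After this commit: `Fbeta` is renamed to `FBeta`, and `F1` is new.
# (Before: `from pytorch_lightning.metrics import Fbeta`.)
from pytorch_lightning.metrics import F1, FBeta

target = torch.tensor([0, 1, 2, 0, 1, 2])
preds = torch.tensor([0, 2, 1, 0, 0, 1])

f_beta = FBeta(num_classes=3, beta=0.5)  # was: Fbeta(num_classes=3, beta=0.5)
f1 = F1(num_classes=3)                   # new class metric

f_beta(preds, target)  # tensor(0.3333), per the FBeta docstring example
f1(preds, target)      # tensor(0.3333), per the F1 docstring example
```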

docs/source/metrics.rst

Lines changed: 12 additions & 6 deletions

@@ -221,10 +221,16 @@ Recall
 .. autoclass:: pytorch_lightning.metrics.classification.Recall
     :noindex:

-Fbeta
+FBeta
 ~~~~~

-.. autoclass:: pytorch_lightning.metrics.classification.Fbeta
+.. autoclass:: pytorch_lightning.metrics.classification.FBeta
+    :noindex:
+
+F1
+~~
+
+.. autoclass:: pytorch_lightning.metrics.classification.F1
     :noindex:

 Regression Metrics
@@ -325,17 +331,17 @@ dice_score [func]
     :noindex:


-f1_score [func]
+f1 [func]
 ~~~~~~~~~~~~~~~

-.. autofunction:: pytorch_lightning.metrics.functional.classification.f1_score
+.. autofunction:: pytorch_lightning.metrics.functional.f1
     :noindex:


-fbeta_score [func]
+fbeta [func]
 ~~~~~~~~~~~~~~~~~~

-.. autofunction:: pytorch_lightning.metrics.functional.classification.fbeta_score
+.. autofunction:: pytorch_lightning.metrics.functional.fbeta
     :noindex:
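The functional renames (`f1_score` -> `f1`, `fbeta_score` -> `fbeta`) also move the documented path from `functional.classification` to the `functional` package root. A hedged usage sketch; the new `functional/f_beta.py` module itself is not shown in this excerpt, so the keyword arguments are assumed to mirror the class constructors:

```python
import torch
from pytorch_lightning.metrics.functional import f1, fbeta  # new names per this commit

target = torch.tensor([0, 1, 2, 0, 1, 2])
preds = torch.tensor([0, 2, 1, 0, 0, 1])

# Assumption: the functional forms take the same arguments as the FBeta/F1 classes.
fbeta_val = fbeta(preds, target, num_classes=3, beta=0.5)
f1_val = f1(preds, target, num_classes=3)
```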

pytorch_lightning/metrics/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -17,7 +17,8 @@
     Accuracy,
     Precision,
     Recall,
-    Fbeta
+    F1,
+    FBeta,
 )

 from pytorch_lightning.metrics.regression import (

pytorch_lightning/metrics/classification/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -13,4 +13,4 @@
 # limitations under the License.
 from pytorch_lightning.metrics.classification.accuracy import Accuracy
 from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall
-from pytorch_lightning.metrics.classification.f_beta import Fbeta
+from pytorch_lightning.metrics.classification.f_beta import FBeta, F1

pytorch_lightning/metrics/classification/f_beta.py

File mode changed: 100644 → 100755
Lines changed: 100 additions & 36 deletions

@@ -11,21 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import math
-import functools
-from abc import ABC, abstractmethod
-from typing import Any, Callable, Optional, Union
-from collections.abc import Mapping, Sequence
-from collections import namedtuple
+from typing import Any, Optional

 import torch
-from torch import nn
+
 from pytorch_lightning.metrics.metric import Metric
-from pytorch_lightning.metrics.functional.reduction import class_reduce
-from pytorch_lightning.metrics.classification.precision_recall import _input_format
+from pytorch_lightning.metrics.functional.f_beta import (
+    _fbeta_update,
+    _fbeta_compute
+)


-class Fbeta(Metric):
+class FBeta(Metric):
     """
     Computes f_beta metric.

@@ -51,7 +48,10 @@ class Fbeta(Metric):

         average:
             * `'micro'` computes metric globally
-            * `'macro'` computes metric for each class and then takes the mean
+            * `'macro'` computes metric for each class and uniformly averages them
+            * `'weighted'` computes metric for each class and does a weighted-average,
+              where each class is weighted by their support (accounts for class imbalance)
+            * `None` computes and returns the metric per class

         multilabel: If predictions are from multilabel classification.
         compute_on_step:
@@ -64,29 +64,28 @@ class Fbeta(Metric):

     Example:

-        >>> from pytorch_lightning.metrics import Fbeta
+        >>> from pytorch_lightning.metrics import FBeta
         >>> target = torch.tensor([0, 1, 2, 0, 1, 2])
         >>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
-        >>> f_beta = Fbeta(num_classes=3, beta=0.5)
+        >>> f_beta = FBeta(num_classes=3, beta=0.5)
         >>> f_beta(preds, target)
         tensor(0.3333)

     """
+
     def __init__(
         self,
-        num_classes: int = 1,
-        beta: float = 1.,
+        num_classes: int,
+        beta: float = 1.0,
         threshold: float = 0.5,
-        average: str = 'micro',
+        average: str = "micro",
         multilabel: bool = False,
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
     ):
         super().__init__(
-            compute_on_step=compute_on_step,
-            dist_sync_on_step=dist_sync_on_step,
-            process_group=process_group,
+            compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group,
         )

         self.num_classes = num_classes
@@ -95,8 +94,10 @@ def __init__(
         self.average = average
         self.multilabel = multilabel

-        assert self.average in ('micro', 'macro'), \
-            "average passed to the function must be either `micro` or `macro`"
+        allowed_average = ("micro", "macro", "weighted", None)
+        if self.average not in allowed_average:
+            raise ValueError('Argument `average` expected to be one of the following:'
+                             f' {allowed_average} but got {self.average}')

         self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
         self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
@@ -110,25 +111,88 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
             preds: Predictions from model
             target: Ground truth values
         """
-        preds, target = _input_format(self.num_classes, preds, target, self.threshold, self.multilabel)
+        true_positives, predicted_positives, actual_positives = _fbeta_update(
+            preds, target, self.num_classes, self.threshold, self.multilabel
+        )

-        self.true_positives += torch.sum(preds * target, dim=1)
-        self.predicted_positives += torch.sum(preds, dim=1)
-        self.actual_positives += torch.sum(target, dim=1)
+        self.true_positives += true_positives
+        self.predicted_positives += predicted_positives
+        self.actual_positives += actual_positives

-    def compute(self):
+    def compute(self) -> torch.Tensor:
         """
-        Computes accuracy over state.
+        Computes fbeta over state.
         """
-        if self.average == 'micro':
-            precision = self.true_positives.sum().float() / (self.predicted_positives.sum())
-            recall = self.true_positives.sum().float() / (self.actual_positives.sum())
+        return _fbeta_compute(self.true_positives, self.predicted_positives,
+                              self.actual_positives, self.beta, self.average)
+
+
+class F1(FBeta):
+    """
+    Computes F1 metric. F1 metrics correspond to a equally weighted average of the
+    precision and recall scores.
+
+    Works with binary, multiclass, and multilabel data.
+    Accepts logits from a model output or integer class values in prediction.
+    Works with multi-dimensional preds and target.
+
+    Forward accepts
+
+    - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
+    - ``target`` (long tensor): ``(N, ...)``
+
+    If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
+    This is the case for binary and multi-label logits.
+
+    If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
+
+    Args:
+        num_classes: Number of classes in the dataset.
+        threshold:
+            Threshold value for binary or multi-label logits. default: 0.5

-        elif self.average == 'macro':
-            precision = self.true_positives.float() / (self.predicted_positives)
-            recall = self.true_positives.float() / (self.actual_positives)
+        average:
+            * `'micro'` computes metric globally
+            * `'macro'` computes metric for each class and uniformly averages them
+            * `'weighted'` computes metric for each class and does a weighted-average,
+              where each class is weighted by their support (accounts for class imbalance)
+            * `None` computes and returns the metric per class

-        num = (1 + self.beta ** 2) * precision * recall
-        denom = self.beta ** 2 * precision + recall
+        multilabel: If predictions are from multilabel classification.
+        compute_on_step:
+            Forward only calls ``update()`` and returns None if this is set to False. default: True
+        dist_sync_on_step:
+            Synchronize metric state across processes at each ``forward()``
+            before returning the value at the step. default: False
+        process_group:
+            Specify the process group on which synchronization is called. default: None (which selects the entire world)
+
+    Example:
+        >>> from pytorch_lightning.metrics import F1
+        >>> target = torch.tensor([0, 1, 2, 0, 1, 2])
+        >>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
+        >>> f1 = F1(num_classes=3)
+        >>> f1(preds, target)
+        tensor(0.3333)
+    """

-        return class_reduce(num=num, denom=denom, weights=None, class_reduction='macro')
+    def __init__(
+        self,
+        num_classes: int = 1,
+        beta: float = 1.0,
+        threshold: float = 0.5,
+        average: str = "micro",
+        multilabel: bool = False,
+        compute_on_step: bool = True,
+        dist_sync_on_step: bool = False,
+        process_group: Optional[Any] = None,
+    ):
+        super().__init__(
+            num_classes=num_classes,
+            beta=1.0,
+            threshold=threshold,
+            average=average,
+            compute_on_step=compute_on_step,
+            dist_sync_on_step=dist_sync_on_step,
+            process_group=process_group,
+        )
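The removed `compute` body above still documents the underlying math: with per-class counts, precision = TP / predicted positives, recall = TP / actual positives, and F_beta = (1 + beta^2) * precision * recall / (beta^2 * precision + recall), reduced according to `average`. The new code delegates this to `_fbeta_compute` in `pytorch_lightning/metrics/functional/f_beta.py`, whose body is not visible in this excerpt. A rough sketch of what such a reduction plausibly looks like; the epsilon handling and the `weighted`/`None` branches are assumptions, not the shipped implementation:

```python
import torch

METRIC_EPS = 1e-6  # assumed; stands in for pytorch_lightning.metrics.utils.METRIC_EPS


def fbeta_compute_sketch(tp, predicted, actual, beta, average):
    """Hypothetical reconstruction of _fbeta_compute, not the shipped code."""
    if average == "micro":
        precision = tp.sum().float() / (predicted.sum() + METRIC_EPS)
        recall = tp.sum().float() / (actual.sum() + METRIC_EPS)
    else:
        precision = tp.float() / (predicted + METRIC_EPS)
        recall = tp.float() / (actual + METRIC_EPS)

    score = (1 + beta ** 2) * precision * recall / (beta ** 2 * precision + recall + METRIC_EPS)

    if average == "macro":
        return score.mean()  # uniform average over classes
    if average == "weighted":
        weights = actual.float() / actual.sum()  # per-class support
        return (weights * score).sum()
    return score  # 'micro' is already a scalar; average=None keeps per-class scores
```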

pytorch_lightning/metrics/classification/precision_recall.py

Lines changed: 7 additions & 30 deletions

@@ -21,34 +21,7 @@
 import torch
 from torch import nn
 from pytorch_lightning.metrics.metric import Metric
-from pytorch_lightning.metrics.utils import to_onehot, METRIC_EPS
-
-
-def _input_format(num_classes: int, preds: torch.Tensor, target: torch.Tensor, threshold=0.5, multilabel=False):
-    if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1):
-        raise ValueError(
-            "preds and target must have same number of dimensions, or one additional dimension for preds"
-        )
-
-    if len(preds.shape) == len(target.shape) + 1:
-        # multi class probabilites
-        preds = torch.argmax(preds, dim=1)
-
-    if len(preds.shape) == len(target.shape) and preds.dtype == torch.long and num_classes > 1 and not multilabel:
-        # multi-class
-        preds = to_onehot(preds, num_classes=num_classes)
-        target = to_onehot(target, num_classes=num_classes)
-
-    elif len(preds.shape) == len(target.shape) and preds.dtype == torch.float:
-        # binary or multilabel probablities
-        preds = (preds >= threshold).long()
-
-    # transpose class as first dim and reshape
-    if len(preds.shape) > 1:
-        preds = preds.transpose(1, 0)
-        target = target.transpose(1, 0)
-
-    return preds.reshape(num_classes, -1), target.reshape(num_classes, -1)
+from pytorch_lightning.metrics.utils import to_onehot, METRIC_EPS, _input_format_classification_one_hot


 class Precision(Metric):
@@ -126,7 +99,9 @@ def __init__(
         self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")

     def update(self, preds: torch.Tensor, target: torch.Tensor):
-        preds, target = _input_format(self.num_classes, preds, target, self.threshold, self.multilabel)
+        preds, target = _input_format_classification_one_hot(
+            self.num_classes, preds, target, self.threshold, self.multilabel
+        )

         # multiply because we are counting (1, 1) pair for true positives
         self.true_positives += torch.sum(preds * target, dim=1)
@@ -221,7 +196,9 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
             preds: Predictions from model
             target: Ground truth values
         """
-        preds, target = _input_format(self.num_classes, preds, target, self.threshold, self.multilabel)
+        preds, target = _input_format_classification_one_hot(
+            self.num_classes, preds, target, self.threshold, self.multilabel
+        )

         # multiply because we are counting (1, 1) pair for true positives
         self.true_positives += torch.sum(preds * target, dim=1)
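The deleted `_input_format` helper is not gone: it moves to `pytorch_lightning/metrics/utils.py` as `_input_format_classification_one_hot` (that file is among the 13 changed but not shown in this excerpt, so the relocated body is assumed equivalent to the one removed above). Its job, per the removed body, is to normalize every input flavour into a pair of `(num_classes, -1)` one-hot tensors. A small shape walk-through of the multi-class branch:

```python
import torch
from torch.nn.functional import one_hot

# Multi-class branch of the removed helper: integer preds/target -> one-hot,
# class dimension first, remaining dimensions flattened.
num_classes = 3
preds = torch.tensor([0, 2, 1])    # (N,) long
target = torch.tensor([0, 1, 2])   # (N,) long

preds_oh = one_hot(preds, num_classes).transpose(1, 0).reshape(num_classes, -1)    # (C, N)
target_oh = one_hot(target, num_classes).transpose(1, 0).reshape(num_classes, -1)  # (C, N)

# A (1, 1) pair counts as a true positive, exactly as in the update() methods above.
true_positives = torch.sum(preds_oh * target_oh, dim=1)  # tensor([1, 0, 0]): only class 0 matched
```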

pytorch_lightning/metrics/functional/__init__.py

Lines changed: 1 addition & 2 deletions

@@ -18,8 +18,6 @@
     average_precision,
     confusion_matrix,
     dice_score,
-    f1_score,
-    fbeta_score,
     multiclass_precision_recall_curve,
     multiclass_roc,
     precision,
@@ -44,3 +42,4 @@
 from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error
 from pytorch_lightning.metrics.functional.psnr import psnr
 from pytorch_lightning.metrics.functional.ssim import ssim
+from pytorch_lightning.metrics.functional.f_beta import fbeta, f1
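With this re-export in place, the commit leaves two entry points: the functional form scores a batch in one call, while the class metrics accumulate state across batches (the `add_state(..., dist_reduce_fx="sum")` registrations above also make the counts sum correctly across processes). A sketch of the accumulation pattern, using only API shown in this diff:

```python
import torch
from pytorch_lightning.metrics import F1

f1 = F1(num_classes=3)

# Each update() adds to the internal count states; compute() reduces once at the
# end, so the epoch-level score is exact rather than a mean of per-batch scores.
batches = [
    (torch.tensor([0, 2, 1]), torch.tensor([0, 1, 2])),
    (torch.tensor([0, 0, 1]), torch.tensor([0, 1, 2])),
]
for preds, target in batches:
    f1.update(preds, target)

epoch_f1 = f1.compute()
```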
