 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import math
-import functools
-from abc import ABC, abstractmethod
-from typing import Any, Callable, Optional, Union
-from collections.abc import Mapping, Sequence
-from collections import namedtuple
+from typing import Any, Optional

 import torch
-from torch import nn
+
 from pytorch_lightning.metrics.metric import Metric
-from pytorch_lightning.metrics.functional.reduction import class_reduce
-from pytorch_lightning.metrics.classification.precision_recall import _input_format
+from pytorch_lightning.metrics.functional.f_beta import (
+    _fbeta_update,
+    _fbeta_compute
+)


-class Fbeta(Metric):
+class FBeta(Metric):
     """
     Computes f_beta metric.

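Note: the per-batch counting and the final reduction now live in `pytorch_lightning.metrics.functional.f_beta` (the new `_fbeta_update` / `_fbeta_compute` imports above); their bodies are not part of this diff. As a minimal sketch of what the micro-averaged path presumably reduces to, using per-class counts taken from the docstring example below (the per-class layout of the state tensors is an assumption here):

    >>> import torch
    >>> tp = torch.tensor([2., 0., 0.])          # true positives per class (assumed layout)
    >>> pred_pos = torch.tensor([3., 2., 1.])    # predicted positives per class
    >>> actual_pos = torch.tensor([2., 2., 2.])  # actual positives per class
    >>> precision = tp.sum() / pred_pos.sum()    # micro: pool counts across classes
    >>> recall = tp.sum() / actual_pos.sum()
    >>> beta = 0.5
    >>> (1 + beta ** 2) * precision * recall / (beta ** 2 * precision + recall)
    tensor(0.3333)

This matches the `tensor(0.3333)` in the docstring example: with micro averaging, precision = recall = 2/6, and the F-beta formula collapses to 1/3 for any beta.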
@@ -51,7 +48,10 @@ class Fbeta(Metric):

         average:
             * `'micro'` computes metric globally
-            * `'macro'` computes metric for each class and then takes the mean
+            * `'macro'` computes metric for each class and uniformly averages them
+            * `'weighted'` computes metric for each class and does a weighted average,
+              where each class is weighted by its support (accounts for class imbalance)
+            * `None` computes and returns the metric per class

         multilabel: If predictions are from multilabel classification.
         compute_on_step:
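The expanded `average` choices can be exercised directly; a short, hypothetical usage sketch (outputs omitted, since the per-class scores for zero-support classes depend on the zero-division handling inside `_fbeta_compute`, which is outside this diff):

    >>> from pytorch_lightning.metrics import FBeta
    >>> fbeta_weighted = FBeta(num_classes=3, beta=0.5, average='weighted')
    >>> fbeta_per_class = FBeta(num_classes=3, beta=0.5, average=None)  # returns one score per class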
@@ -64,29 +64,28 @@ class Fbeta(Metric):

     Example:

-        >>> from pytorch_lightning.metrics import Fbeta
+        >>> from pytorch_lightning.metrics import FBeta
         >>> target = torch.tensor([0, 1, 2, 0, 1, 2])
         >>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
-        >>> f_beta = Fbeta(num_classes=3, beta=0.5)
+        >>> f_beta = FBeta(num_classes=3, beta=0.5)
         >>> f_beta(preds, target)
         tensor(0.3333)

     """
+
     def __init__(
         self,
-        num_classes: int = 1,
-        beta: float = 1.,
+        num_classes: int,
+        beta: float = 1.0,
         threshold: float = 0.5,
-        average: str = 'micro',
+        average: str = "micro",
         multilabel: bool = False,
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
     ):
         super().__init__(
-            compute_on_step=compute_on_step,
-            dist_sync_on_step=dist_sync_on_step,
-            process_group=process_group,
+            compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group,
         )

         self.num_classes = num_classes
@@ -95,8 +94,10 @@ def __init__(
         self.average = average
         self.multilabel = multilabel

-        assert self.average in ('micro', 'macro'), \
-            "average passed to the function must be either `micro` or `macro`"
+        allowed_average = ("micro", "macro", "weighted", None)
+        if self.average not in allowed_average:
+            raise ValueError('Argument `average` expected to be one of the following:'
+                             f' {allowed_average} but got {self.average}')

         self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
         self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
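Replacing the `assert` with an explicit `ValueError` means invalid arguments now fail loudly even under `python -O`, which strips assertions. Per the message constructed above, a bad value would surface roughly as:

    >>> FBeta(num_classes=3, average='samples')
    Traceback (most recent call last):
        ...
    ValueError: Argument `average` expected to be one of the following: ('micro', 'macro', 'weighted', None) but got samples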
@@ -110,25 +111,88 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
             preds: Predictions from model
             target: Ground truth values
         """
-        preds, target = _input_format(self.num_classes, preds, target, self.threshold, self.multilabel)
+        true_positives, predicted_positives, actual_positives = _fbeta_update(
+            preds, target, self.num_classes, self.threshold, self.multilabel
+        )

-        self.true_positives += torch.sum(preds * target, dim=1)
-        self.predicted_positives += torch.sum(preds, dim=1)
-        self.actual_positives += torch.sum(target, dim=1)
+        self.true_positives += true_positives
+        self.predicted_positives += predicted_positives
+        self.actual_positives += actual_positives

-    def compute(self):
+    def compute(self) -> torch.Tensor:
         """
-        Computes accuracy over state.
+        Computes fbeta over state.
         """
-        if self.average == 'micro':
-            precision = self.true_positives.sum().float() / (self.predicted_positives.sum())
-            recall = self.true_positives.sum().float() / (self.actual_positives.sum())
+        return _fbeta_compute(self.true_positives, self.predicted_positives,
+                              self.actual_positives, self.beta, self.average)
+
+
+class F1(FBeta):
+    """
+    Computes F1 metric. F1 metrics correspond to the harmonic mean of the
+    precision and recall scores.
+
+    Works with binary, multiclass, and multilabel data.
+    Accepts logits from a model output or integer class values in prediction.
+    Works with multi-dimensional preds and target.
+
+    Forward accepts
+
+    - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
+    - ``target`` (long tensor): ``(N, ...)``
+
+    If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
+    This is the case for binary and multi-label logits.
+
+    If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
+
+    Args:
+        num_classes: Number of classes in the dataset.
+        threshold:
+            Threshold value for binary or multi-label logits. default: 0.5

-        elif self.average == 'macro':
-            precision = self.true_positives.float() / (self.predicted_positives)
-            recall = self.true_positives.float() / (self.actual_positives)
+        average:
+            * `'micro'` computes metric globally
+            * `'macro'` computes metric for each class and uniformly averages them
+            * `'weighted'` computes metric for each class and does a weighted average,
+              where each class is weighted by its support (accounts for class imbalance)
+            * `None` computes and returns the metric per class

-        num = (1 + self.beta ** 2) * precision * recall
-        denom = self.beta ** 2 * precision + recall
+        multilabel: If predictions are from multilabel classification.
+        compute_on_step:
+            Forward only calls ``update()`` and returns None if this is set to False. default: True
+        dist_sync_on_step:
+            Synchronize metric state across processes at each ``forward()``
+            before returning the value at the step. default: False
+        process_group:
+            Specify the process group on which synchronization is called. default: None (which selects the entire world)
+
+    Example:
+        >>> from pytorch_lightning.metrics import F1
+        >>> target = torch.tensor([0, 1, 2, 0, 1, 2])
+        >>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
+        >>> f1 = F1(num_classes=3)
+        >>> f1(preds, target)
+        tensor(0.3333)
+    """

-        return class_reduce(num=num, denom=denom, weights=None, class_reduction='macro')
+    def __init__(
+        self,
+        num_classes: int = 1,
+        threshold: float = 0.5,
+        average: str = "micro",
+        multilabel: bool = False,
+        compute_on_step: bool = True,
+        dist_sync_on_step: bool = False,
+        process_group: Optional[Any] = None,
+    ):
+        super().__init__(
+            num_classes=num_classes,
+            beta=1.0,
+            threshold=threshold,
+            average=average,
+            multilabel=multilabel,
+            compute_on_step=compute_on_step,
+            dist_sync_on_step=dist_sync_on_step,
+            process_group=process_group,
+        )
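Since `F1` simply pins `beta=1.0` in its `super().__init__()` call (it no longer accepts a `beta` argument of its own), it should agree with `FBeta(beta=1.0)` on the docstring example; a quick doctest-style consistency check, reusing the tensors from the examples in this diff (both sides evaluate to `tensor(0.3333)`):

    >>> import torch
    >>> from pytorch_lightning.metrics import F1, FBeta
    >>> target = torch.tensor([0, 1, 2, 0, 1, 2])
    >>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
    >>> torch.allclose(F1(num_classes=3)(preds, target), FBeta(num_classes=3, beta=1.0)(preds, target))
    True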