1- import torch .nn as nn
21import torch
2+ import torch .nn as nn
33
44
55class F1Score (nn .Module ):
66 """
7- F1 Score implementation with direct averaging inside the compute method.
7+ F1 Score implementation with direct averaging inside the compute method.
8+
9+ Parameters
10+ ----------
11+ num_classes : int
12+ Number of classes.
813
9- Parameters
10- ----------
11- num_classes : int
12- Number of classes.
14+ Attributes
15+ ----------
16+ num_classes : int
17+ The number of classes.
1318
14- Attributes
15- ----------
16- num_classes : int
17- The number of classes.
19+ tp : torch.Tensor
20+ Tensor for True Positives (TP) for each class.
1821
19- tp : torch.Tensor
20- Tensor for True Positives (TP ) for each class.
22+ fp : torch.Tensor
23+ Tensor for False Positives (FP ) for each class.
2124
22- fp : torch.Tensor
23- Tensor for False Positives (FP) for each class.
25+ fn : torch.Tensor
26+ Tensor for False Negatives (FN) for each class.
27+ """
2428
25- fn : torch.Tensor
26- Tensor for False Negatives (FN) for each class.
27- """
2829 def __init__ (self , num_classes ):
2930 """
30- Initializes the F1Score object, setting up the necessary state variables.
31+ Initializes the F1Score object, setting up the necessary state variables.
3132
32- Parameters
33- ----------
34- num_classes : int
35- The number of classes in the classification task.
33+ Parameters
34+ ----------
35+ num_classes : int
36+ The number of classes in the classification task.
3637
37- """
38+ """
3839
3940 super ().__init__ ()
4041
@@ -47,16 +48,16 @@ def __init__(self, num_classes):
4748
4849 def update (self , preds , target ):
4950 """
50- Update the variables with predictions and true labels.
51+ Update the variables with predictions and true labels.
5152
52- Parameters
53- ----------
54- preds : torch.Tensor
55- Predicted logits (shape: [batch_size, num_classes]).
53+ Parameters
54+ ----------
55+ preds : torch.Tensor
56+ Predicted logits (shape: [batch_size, num_classes]).
5657
57- target : torch.Tensor
58- True labels (shape: [batch_size]).
59- """
58+ target : torch.Tensor
59+ True labels (shape: [batch_size]).
60+ """
6061 preds = torch .argmax (preds , dim = 1 )
6162
6263 # Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class
@@ -76,17 +77,20 @@ def compute(self):
7677 """
7778
7879 # Compute F1 score based on the specified averaging method
79- f1_score = 2 * torch .sum (self .tp ) / (2 * torch .sum (self .tp ) + torch .sum (self .fp ) + torch .sum (self .fn ))
80+ f1_score = (
81+ 2
82+ * torch .sum (self .tp )
83+ / (2 * torch .sum (self .tp ) + torch .sum (self .fp ) + torch .sum (self .fn ))
84+ )
8085
8186 return f1_score
8287
8388
8489def test_f1score ():
8590 f1_metric = F1Score (num_classes = 3 )
86- preds = torch .tensor ([[0.8 , 0.1 , 0.1 ],
87- [0.2 , 0.7 , 0.1 ],
88- [0.2 , 0.3 , 0.5 ],
89- [0.1 , 0.2 , 0.7 ]])
91+ preds = torch .tensor (
92+ [[0.8 , 0.1 , 0.1 ], [0.2 , 0.7 , 0.1 ], [0.2 , 0.3 , 0.5 ], [0.1 , 0.2 , 0.7 ]]
93+ )
9094
9195 target = torch .tensor ([0 , 1 , 0 , 2 ])
9296
0 commit comments