Skip to content

Commit c5604f6

Browse files
committed
fixed error in F1.py
1 parent 661b622 commit c5604f6

File tree

1 file changed

+1
-12
lines changed

1 file changed

+1
-12
lines changed

utils/metrics/F1.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(self, num_classes):
4646
self.fp = torch.zeros(num_classes)
4747
self.fn = torch.zeros(num_classes)
4848

49-
def update(self, preds, target):
49+
def forward(self, preds, target):
5050
"""
5151
Update the variables with predictions and true labels.
5252
@@ -66,17 +66,6 @@ def update(self, preds, target):
6666
self.fp[i] += torch.sum((preds == i) & (target != i)).float()
6767
self.fn[i] += torch.sum((preds != i) & (target == i)).float()
6868

69-
def compute(self):
70-
"""
71-
Compute the F1 score.
72-
73-
Returns
74-
-------
75-
torch.Tensor
76-
The computed F1 score.
77-
"""
78-
79-
# Compute F1 score based on the specified averaging method
8069
f1_score = (
8170
2
8271
* torch.sum(self.tp)

0 commit comments

Comments
 (0)