Skip to content

Commit e8999ab

Browse files
committed
updated F1 test function
1 parent 4df3cd2 commit e8999ab

File tree

1 file changed

+24
-9
lines changed

1 file changed

+24
-9
lines changed

tests/test_metrics.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,17 +71,32 @@ def test_recall():
7171
def test_f1score():
7272
import torch
7373

74-
f1_metric = F1Score(num_classes=3)
75-
preds = torch.tensor(
76-
[[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.3, 0.5], [0.1, 0.2, 0.7]]
77-
)
74+
# Example case with known output
75+
y_true = torch.tensor([0, 1, 2, 2, 1, 0]) # True labels
76+
y_pred = torch.tensor([0, 1, 1, 2, 0, 0]) # Predicted labels
77+
78+
# Create F1Score object for micro and macro averaging
79+
f1_micro = F1Score(num_classes=3, macro_averaging=False)
80+
f1_macro = F1Score(num_classes=3, macro_averaging=True)
81+
82+
# Update F1 score with predictions
83+
f1_micro(y_true, y_pred)
84+
f1_macro(y_true, y_pred)
85+
86+
# Get F1 scores
87+
micro_f1_score = f1_micro.__returnmetric__()
88+
macro_f1_score = f1_macro.__returnmetric__()
89+
90+
# Check if outputs are tensors
91+
assert isinstance(micro_f1_score, torch.Tensor), "Micro F1 score should be a tensor."
92+
assert isinstance(macro_f1_score, torch.Tensor), "Macro F1 score should be a tensor."
7893

79-
target = torch.tensor([0, 1, 0, 2])
94+
# Check that F1 scores are between 0 and 1
95+
assert 0 <= micro_f1_score.item() <= 1, "Micro F1 score should be between 0 and 1."
96+
assert 0 <= macro_f1_score.item() <= 1, "Macro F1 score should be between 0 and 1."
8097

81-
f1_metric(preds, target)
82-
assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
83-
assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
84-
assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."
98+
print(f"Micro F1 Score: {micro_f1_score.item()}")
99+
print(f"Macro F1 Score: {macro_f1_score.item()}")
85100

86101

87102
def test_precision():

0 commit comments

Comments
 (0)