@@ -71,17 +71,32 @@ def test_recall():
 def test_f1score():
     import torch
 
-    f1_metric = F1Score(num_classes=3)
-    preds = torch.tensor(
-        [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.3, 0.5], [0.1, 0.2, 0.7]]
-    )
+    # Example case with known output
+    y_true = torch.tensor([0, 1, 2, 2, 1, 0])  # True labels
+    y_pred = torch.tensor([0, 1, 1, 2, 0, 0])  # Predicted labels
+
+    # Create F1Score objects for micro and macro averaging
+    f1_micro = F1Score(num_classes=3, macro_averaging=False)
+    f1_macro = F1Score(num_classes=3, macro_averaging=True)
+
+    # Update F1 score with predictions
+    f1_micro(y_true, y_pred)
+    f1_macro(y_true, y_pred)
+
+    # Get F1 scores
+    micro_f1_score = f1_micro.__returnmetric__()
+    macro_f1_score = f1_macro.__returnmetric__()
+
+    # Check if outputs are tensors
+    assert isinstance(micro_f1_score, torch.Tensor), "Micro F1 score should be a tensor."
+    assert isinstance(macro_f1_score, torch.Tensor), "Macro F1 score should be a tensor."
 
-    target = torch.tensor([0, 1, 0, 2])
+    # Check that F1 scores are between 0 and 1
+    assert 0 <= micro_f1_score.item() <= 1, "Micro F1 score should be between 0 and 1."
+    assert 0 <= macro_f1_score.item() <= 1, "Macro F1 score should be between 0 and 1."
 
-    f1_metric(preds, target)
-    assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
-    assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
-    assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."
+    print(f"Micro F1 Score: {micro_f1_score.item()}")
+    print(f"Macro F1 Score: {macro_f1_score.item()}")
 
 
 def test_precision():
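The added assertions only check the result type and range. Because the example labels are fixed, the exact values are also known (micro F1 = 4/6 ≈ 0.667, macro F1 ≈ 0.656), so a stricter follow-up could compare against a reference implementation. A minimal sketch using scikit-learn as the reference, assuming it is available in the test environment and that the argument order (y_true, y_pred) matches the test above:

import torch
from sklearn.metrics import f1_score

# Same example labels as in the new test
y_true = torch.tensor([0, 1, 2, 2, 1, 0])
y_pred = torch.tensor([0, 1, 1, 2, 0, 0])

# Per-class F1: class 0 -> 0.8, class 1 -> 0.5, class 2 -> 2/3
ref_micro = f1_score(y_true.numpy(), y_pred.numpy(), average="micro")  # 4/6 ≈ 0.6667
ref_macro = f1_score(y_true.numpy(), y_pred.numpy(), average="macro")  # ≈ 0.6556

# Hypothetical tighter assertions, reusing the test's variables:
# assert abs(micro_f1_score.item() - ref_micro) < 1e-6
# assert abs(macro_f1_score.item() - ref_macro) < 1e-6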