class TestCustomMacroF1Metric(unittest.TestCase):
-
    @classmethod
    def setUpClass(cls) -> None:
        cls.device = "cuda" if torch.cuda.is_available() else "cpu"

-    def test_all_predictions_are_1_half_labels_are_1(self):
-        """Test custom metric against standard metric for the scenario where all prediction are 1 but only half of
-        the labels are 1"""
+    def test_all_predictions_are_1_half_labels_are_1(self) -> None:
+        """
+        Test custom metric against standard metric for the scenario where all predictions are 1
+        but only half of the labels are 1.
+        """
        preds = torch.ones((1, 900), dtype=torch.int)
        label = torch.ones((1, 900), dtype=torch.int)

+        # Randomly set half of the labels to 0
        mask = [
            [True] * (label.size(1) // 2)
            + [False] * (label.size(1) - (label.size(1) // 2))
        ]
        random.shuffle(mask[0])
        label[torch.tensor(mask)] = 0

+        # Get custom and standard metric scores
        macro_f1_custom_score, macro_f1_standard_score = (
            self.__get_custom_and_standard_metric_scores(label.shape[1], preds, label)
        )
@@ -52,25 +55,29 @@ def test_all_predictions_are_1_half_labels_are_1(self):
        # precision = [1, 1, 1, 1, 1] / [1, 1, 1, 1, 1] = [1, 1, 1, 1, 1]
        # recall = [1, 1, 1, 1, 1] / [1, 1, 1, 1, 1] = [1, 1, 1, 1, 1]
        # classwise_f1 = [2, 2, 2, 2, 2] / [2, 2, 2, 2, 2] = [1, 1, 1, 1, 1]
-        # mean = 5/5 = 1 (because of masking we averaging with across positive labels only)
+        # mean = 5/5 = 1 (because of masking we're averaging across positive labels only)
        self.assertAlmostEqual(macro_f1_custom_score, 1, places=4)
        self.assertNotAlmostEqual(
            macro_f1_custom_score, macro_f1_standard_score, places=4
        )

-    def test_all_labels_are_1_half_predictions_are_1(self):
-        """Test custom metric against standard metric for the scenario where all labels are 1 but only half of
-        the predictions are 1"""
+    def test_all_labels_are_1_half_predictions_are_1(self) -> None:
+        """
+        Test custom metric against standard metric for the scenario where all labels are 1
+        but only half of the predictions are 1.
+        """
        preds = torch.ones((1, 900), dtype=torch.int)
        label = torch.ones((1, 900), dtype=torch.int)

+        # Randomly set half of the predictions to 0
        mask = [
            [True] * (label.size(1) // 2)
            + [False] * (label.size(1) - (label.size(1) // 2))
        ]
        random.shuffle(mask[0])
        preds[torch.tensor(mask)] = 0

+        # Get custom and standard metric scores
        macro_f1_custom_score, macro_f1_standard_score = (
            self.__get_custom_and_standard_metric_scores(label.shape[1], preds, label)
        )
@@ -79,9 +86,11 @@ def test_all_labels_are_1_half_predictions_are_1(self):
        # and since all labels are positive in this scenario, custom and std metric are same
        self.assertAlmostEqual(macro_f1_custom_score, macro_f1_standard_score, places=4)

-    def test_iterative_vs_single_call_approach(self):
-        """Test the custom metric implementation in update fashion approach against
-        the single call approach"""
+    def test_iterative_vs_single_call_approach(self) -> None:
+        """
+        Test the custom metric implementation using the iterative (update-based) approach
+        against the single-call approach.
+        """
        preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
        label = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]])

@@ -94,24 +103,27 @@ def test_iterative_vs_single_call_approach(self):
        single_call_custom_metric = MacroF1(num_labels=num_labels)
        single_call_custom_metric_score = single_call_custom_metric(preds, label).item()

+        # Assert iterative and single call approaches give the same metric score
        self.assertEqual(iterative_custom_metric_score, single_call_custom_metric_score)

-    def test_metric_against_realistic_data(self):
-        """Test the custom metric against the standard on realistic data"""
+    def test_metric_against_realistic_data(self) -> None:
+        """
+        Test the custom metric against the standard metric on realistic data.
+        """
        directory_path = os.path.join("tests", "test_data", "CheBIOver100_test")
        abs_path = os.path.join(os.getcwd(), directory_path)
        print(f"Checking data from - {abs_path}")
        num_of_files = len(os.listdir(abs_path)) // 2

-        # load single file to get the num of labels for metric class instantiation
+        # Load a single file to get the number of labels for metric class instantiation
        labels = torch.load(
            f"{directory_path}/labels{0:03d}.pt", map_location=torch.device(self.device)
        )
        num_labels = labels.shape[1]
        macro_f1_custom = MacroF1(num_labels=num_labels)
        macro_f1_standard = MultilabelF1Score(num_labels=num_labels, average="macro")

-        # load each file in the directory and update the stats
+        # Load each file in the directory and update the metrics
        for i in range(num_of_files):
            labels = torch.load(
                f"{directory_path}/labels{i:03d}.pt",
@@ -130,14 +142,19 @@ def test_metric_against_realistic_data(self):
130142 f"Realistic Data - Custom F1 score: { macro_f1_custom_score } , Std. F1 score: { macro_f1_standard_score } "
131143 )
132144
145+ # Assert custom metric score is not equal to standard metric score
133146 self .assertNotAlmostEqual (
134147 macro_f1_custom_score , macro_f1_standard_score , places = 4
135148 )
136149
-    def test_case_when_few_class_has_no_labels(self):
-        """Test custom metric against standard metric for the scenario where some class has no labels"""
+    def test_case_when_few_class_has_no_labels(self) -> None:
+        """
+        Test custom metric against standard metric for the scenario where some classes have no positive labels.
+        """
        preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
        label = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]])
+
+        # Get custom and standard metric scores
        macro_f1_custom_score, macro_f1_standard_score = (
            self.__get_custom_and_standard_metric_scores(label.shape[1], preds, label)
        )
@@ -170,7 +187,22 @@ def test_case_when_few_class_has_no_labels(self):
        )

    @staticmethod
-    def __get_custom_and_standard_metric_scores(num_labels, preds, labels):
+    def __get_custom_and_standard_metric_scores(
+        num_labels: int, preds: torch.Tensor, labels: torch.Tensor
+    ) -> tuple:
+        """
+        Helper method to calculate custom and standard macro F1 scores.
+
+        Args:
+            num_labels (int): Number of labels/classes.
+            preds (torch.Tensor): Predicted tensor of shape (batch_size, num_labels).
+            labels (torch.Tensor): True labels tensor of shape (batch_size, num_labels).
+
+        Returns:
+            tuple: A tuple containing two floats:
+                - macro_f1_custom_score: Custom macro F1 score.
+                - macro_f1_standard_score: Standard macro F1 score.
+        """
        # Custom metric score
        macro_f1_custom = MacroF1(num_labels=num_labels)
        macro_f1_custom_score = macro_f1_custom(preds, labels).item()
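
Note: the MacroF1 implementation under test is not part of this diff. The snippet below is only a minimal sketch of the behaviour these tests describe, not the project's actual code; the helper name masked_macro_f1 and its internals are assumptions. Per-class F1 is computed from TP/FP/FN counts and then averaged only over classes that have at least one positive label, which is why the custom score can diverge from torchmetrics' MultilabelF1Score with average="macro" whenever some classes have no positive labels.

import torch

def masked_macro_f1(preds: torch.Tensor, labels: torch.Tensor) -> float:
    """Hypothetical sketch: macro F1 averaged only over classes with positive labels."""
    # preds/labels: (batch_size, num_labels) tensors containing 0/1 values.
    tp = ((preds == 1) & (labels == 1)).sum(dim=0).float()
    fp = ((preds == 1) & (labels == 0)).sum(dim=0).float()
    fn = ((preds == 0) & (labels == 1)).sum(dim=0).float()

    precision = tp / (tp + fp).clamp(min=1e-12)
    recall = tp / (tp + fn).clamp(min=1e-12)
    classwise_f1 = 2 * precision * recall / (precision + recall).clamp(min=1e-12)

    # Keep only classes that have at least one positive label before averaging.
    positive_classes = labels.sum(dim=0) > 0
    return classwise_f1[positive_classes].mean().item()

# Toy tensors from test_case_when_few_class_has_no_labels: class 0 has no positive
# labels, so it is excluded from the average, unlike in the standard macro F1.
preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
labels = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]])
print(masked_macro_f1(preds, labels))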