@@ -96,7 +96,6 @@ def test_iterative_vs_single_call_approach(self):
 
         self.assertEqual(iterative_custom_metric_score, single_call_custom_metric_score)
 
-    @unittest.expectedFailure
     def test_metric_against_realistic_data(self):
         """Test the custom metric against the standard on realistic data"""
         directory_path = os.path.join("tests", "test_data", "CheBIOver100_test")
@@ -131,9 +130,10 @@ def test_metric_against_realistic_data(self):
             f"Realistic Data - Custom F1 score: {macro_f1_custom_score}, Std. F1 score: {macro_f1_standard_score}"
         )
 
-        self.assertAlmostEqual(macro_f1_custom_score, macro_f1_standard_score, places=4)
+        self.assertNotAlmostEqual(
+            macro_f1_custom_score, macro_f1_standard_score, places=4
+        )
 
-    @unittest.expectedFailure
     def test_case_when_few_class_has_no_labels(self):
         """Test custom metric against standard metric for the scenario where some class has no labels"""
         preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
@@ -142,7 +142,32 @@ def test_case_when_few_class_has_no_labels(self):
             self.__get_custom_and_standard_metric_scores(label.shape[1], preds, label)
         )
 
-        self.assertAlmostEqual(macro_f1_custom_score, macro_f1_standard_score, places=4)
+        # tps = [0, 1, 1, 2]
+        # positive_predictions = [2, 2, 1, 3]
+        # positive_labels = [0, 1, 1, 2]
+
+        # ---------------------- For Standard F1 Macro Metric ---------------------
+        # The metric is only properly defined when TP + FP ≠ 0 ∧ TP + FN ≠ 0.
+        # If this case is encountered for any class/label, the metric for that
+        # class/label is set to 0 and the overall metric may be affected in turn.
+
+        # precision = [0, 1, 1, 2] / [2, 2, 1, 3] = [0, 0.5, 1, 0.66666667]
+        # recall = [0, 1, 1, 2] / [0, 1, 1, 2] = [0, 1, 1, 1]  (0/0 is set to 0)
+        # classwise_f1 = [0, 1, 2, 1.33333334] / [0, 1.5, 2, 1.66666667] = [0, 0.66666667, 1, 0.8]
+        # mean = 2.46666667 / 4 = 0.6166666681  (the zeroed class is still averaged in)
+
+        # ----------------------- For Custom F1 Metric ----------------------------
+        # As a first step, mask out classes that have no positive labels:
+        # mask = [False, True, True, True]
+        # precision = [1, 1, 2] / [2, 1, 3] = [0.5, 1, 0.66666667]
+        # recall = [1, 1, 2] / [1, 1, 2] = [1, 1, 1]
+        # classwise_f1 = [1, 2, 1.33333334] / [1.5, 2, 1.66666667] = [0.66666667, 1, 0.8]
+        # mean = 2.46666667 / 3 = 0.8222222241  (masking averages over classes with positive labels only)
+
+        self.assertAlmostEqual(macro_f1_custom_score, 0.8222222241, places=4)
+        self.assertNotAlmostEqual(
+            macro_f1_custom_score, macro_f1_standard_score, places=4
+        )
 
     @staticmethod
     def __get_custom_and_standard_metric_scores(num_labels, preds, labels):
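For reference, here is a minimal standalone sketch (not the repository's implementation) that reproduces the two values walked through in the comments above: a standard macro F1 of about 0.6167 and a masked, custom-style macro F1 of about 0.8222. The labels tensor below is hypothetical, chosen only so that its per-class counts match tps = [0, 1, 1, 2] and positive_labels = [0, 1, 1, 2]; the test's actual label tensor is not shown in this hunk.

# Hedged sketch: standard macro F1 via torchmetrics vs. a hand-rolled masked macro F1.
# `labels` is a hypothetical tensor; only its column sums [0, 1, 1, 2] matter here.
import torch
from torchmetrics.classification import MultilabelF1Score

preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
labels = torch.tensor([[0, 1, 0, 1], [0, 0, 1, 1], [0, 0, 0, 0]])  # hypothetical

# Standard macro F1: class 0 has no positive labels, its F1 is 0,
# and it still counts toward the mean over all 4 classes.
standard_f1 = MultilabelF1Score(num_labels=4, average="macro")(preds, labels)
print(float(standard_f1))  # ~0.6167

# Masked variant: drop classes without positive labels before averaging.
tp = (preds * labels).sum(dim=0).float()    # [0, 1, 1, 2]
pred_pos = preds.sum(dim=0).float()         # [2, 2, 1, 3]
label_pos = labels.sum(dim=0).float()       # [0, 1, 1, 2]
mask = label_pos > 0                        # [False, True, True, True]

precision = tp[mask] / pred_pos[mask]
recall = tp[mask] / label_pos[mask]
classwise_f1 = 2 * precision * recall / (precision + recall)
print(float(classwise_f1.mean()))           # ~0.8222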