
Commit 3ce690c

Update testCustomMacroF1Metric.py
1 parent ce56bdb

1 file changed (+29, −4)

tests/testCustomMacroF1Metric.py

Lines changed: 29 additions & 4 deletions
@@ -96,7 +96,6 @@ def test_iterative_vs_single_call_approach(self):
 
         self.assertEqual(iterative_custom_metric_score, single_call_custom_metric_score)
 
-    @unittest.expectedFailure
     def test_metric_against_realistic_data(self):
         """Test the custom metric against the standard on realistic data"""
         directory_path = os.path.join("tests", "test_data", "CheBIOver100_test")
@@ -131,9 +130,10 @@ def test_metric_against_realistic_data(self):
             f"Realistic Data - Custom F1 score: {macro_f1_custom_score}, Std. F1 score: {macro_f1_standard_score}"
         )
 
-        self.assertAlmostEqual(macro_f1_custom_score, macro_f1_standard_score, places=4)
+        self.assertNotAlmostEqual(
+            macro_f1_custom_score, macro_f1_standard_score, places=4
+        )
 
-    @unittest.expectedFailure
     def test_case_when_few_class_has_no_labels(self):
         """Test custom metric against standard metric for the scenario where some class has no labels"""
         preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
@@ -142,7 +142,32 @@ def test_case_when_few_class_has_no_labels(self):
             self.__get_custom_and_standard_metric_scores(label.shape[1], preds, label)
         )
 
-        self.assertAlmostEqual(macro_f1_custom_score, macro_f1_standard_score, places=4)
+        # tps = [0, 1, 1, 2]
+        # positive_predictions = [2, 2, 1, 3]
+        # positive_labels = [0, 1, 1, 2]
+
+        # ---------------------- For Standard F1 Macro Metric ---------------------
+        # The metric is only properly defined when TP + FP != 0 and TP + FN != 0.
+        # If this case is encountered for any class/label, the metric for that
+        # class/label is set to 0, and the overall metric may be affected in turn.
+
+        # precision = [0, 1, 1, 2] / [2, 2, 1, 3] = [0, 0.5, 1, 0.66666667]
+        # recall = [0, 1, 1, 2] / [0, 1, 1, 2] = [0, 1, 1, 1]
+        # classwise_f1 = [0, 1, 2, 1.33333333] / [0, 1.5, 2, 1.66666667] = [0, 0.66666667, 1, 0.8]
+        # mean = 2.46666667 / 4 ≈ 0.61666667
+
+        # ----------------------- For Custom F1 Metric ----------------------------
+        # Perform masking as a first step to keep only the classes with positive labels.
+        # mask = [False, True, True, True]
+        # precision = [1, 1, 2] / [2, 1, 3] = [0.5, 1, 0.66666667]
+        # recall = [1, 1, 2] / [1, 1, 2] = [1, 1, 1]
+        # classwise_f1 = [1, 2, 1.33333333] / [1.5, 2, 1.66666667] = [0.66666667, 1, 0.8]
+        # mean = 2.46666667 / 3 ≈ 0.82222222 (because of the masking we average across the positive classes only)
+
+        self.assertAlmostEqual(macro_f1_custom_score, 0.8222222241, places=4)
+        self.assertNotAlmostEqual(
+            macro_f1_custom_score, macro_f1_standard_score, places=4
+        )
 
     @staticmethod
     def __get_custom_and_standard_metric_scores(num_labels, preds, labels):
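For reference, the class-wise arithmetic worked through in the comments above can be reproduced with a few lines of plain torch. The sketch below is illustrative only, not the repository's metric implementation, and the labels tensor is an assumption: it is chosen solely so that the per-class counts match tps = [0, 1, 1, 2] and positive_labels = [0, 1, 1, 2] quoted in the comments (the label tensor actually used by the test lies outside this diff).

import torch

preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
# Hypothetical labels, consistent with the commented counts; not taken from the test file.
labels = torch.tensor([[0, 1, 0, 1], [0, 0, 1, 1], [0, 0, 0, 0]])

tp = ((preds == 1) & (labels == 1)).sum(dim=0).float()  # tensor([0., 1., 1., 2.])
fp = ((preds == 1) & (labels == 0)).sum(dim=0).float()  # tensor([2., 1., 0., 1.])
fn = ((preds == 0) & (labels == 1)).sum(dim=0).float()  # tensor([0., 0., 0., 0.])

# Class-wise F1 = 2*TP / (2*TP + FP + FN), with the 0/0 -> 0 convention for the
# class that has neither positive labels nor positive predictions contributing TP.
denom = 2 * tp + fp + fn
classwise_f1 = torch.where(denom > 0, 2 * tp / denom, torch.zeros_like(denom))
# -> tensor([0.0000, 0.6667, 1.0000, 0.8000])

standard_macro_f1 = classwise_f1.mean()        # ≈ 0.6167, averaged over all 4 classes

mask = labels.sum(dim=0) > 0                   # tensor([False, True, True, True])
custom_macro_f1 = classwise_f1[mask].mean()    # ≈ 0.8222, averaged over the 3 masked classes

print(standard_macro_f1.item(), custom_macro_f1.item())

Averaging over all four classes reproduces the 0.61666667 given for the standard macro F1, while restricting the mean to the three classes with at least one positive label reproduces the 0.82222222 that the updated test asserts for the custom metric.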
