
Commit a4c970d

added testing for custom metric + custom balanced acc implementation
1 parent 2e8ef97 commit a4c970d

6 files changed: +346 -2 lines changed

chebai/callbacks/epoch_metrics.py

Lines changed: 58 additions & 0 deletions
@@ -47,3 +47,61 @@ def compute(self):
        # if (precision and recall are 0) or (precision is nan), set f1 to 0
        classwise_f1 = classwise_f1.nan_to_num()
        return torch.mean(classwise_f1)


class BalancedAccuracy(torchmetrics.Metric):
    def __init__(self, num_labels, dist_sync_on_step=False, threshold=0.5):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.add_state(
            "true_positives",
            default=torch.zeros(num_labels, dtype=torch.int),
            dist_reduce_fx="sum",
        )

        self.add_state(
            "false_positives",
            default=torch.zeros(num_labels, dtype=torch.int),
            dist_reduce_fx="sum",
        )

        self.add_state(
            "true_negatives",
            default=torch.zeros(num_labels, dtype=torch.int),
            dist_reduce_fx="sum",
        )

        self.add_state(
            "false_negatives",
            default=torch.zeros(num_labels, dtype=torch.int),
            dist_reduce_fx="sum",
        )

        self.threshold = threshold

    def update(self, preds: torch.Tensor, labels: torch.Tensor):
        """Update the TPs, TNs, FPs and FNs."""

        # preds and labels have size: batch_size x num_of_classes;
        # summing over the batch dimension (dim=0) gives the per-class counts
        tps = torch.sum(torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0)
        fps = torch.sum(torch.logical_and(preds > self.threshold, ~labels.to(torch.bool)), dim=0)
        tns = torch.sum(torch.logical_and(preds <= self.threshold, ~labels.to(torch.bool)), dim=0)
        fns = torch.sum(torch.logical_and(preds <= self.threshold, labels.to(torch.bool)), dim=0)

        # Size: num_of_classes
        self.true_positives += tps
        self.false_positives += fps
        self.true_negatives += tns
        self.false_negatives += fns

    def compute(self):
        """Compute balanced accuracy from the accumulated per-class counts."""

        tpr = self.true_positives / (self.true_positives + self.false_negatives)
        tnr = self.true_negatives / (self.true_negatives + self.false_positives)
        # convert NaN values (classes with no positives or no negatives) to 0
        tpr = tpr.nan_to_num()
        tnr = tnr.nan_to_num()

        balanced_acc = (tpr + tnr) / 2
        return torch.mean(balanced_acc)
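
For orientation, a minimal usage sketch of the new metric outside the Lightning configs; the tensors below are illustrative and not part of the commit:

import torch
from chebai.callbacks.epoch_metrics import BalancedAccuracy

# toy batch: 3 samples, 4 classes; per-class balanced accuracy = (TPR + TNR) / 2
preds = torch.tensor([[0.9, 0.2, 0.7, 0.6],
                      [0.1, 0.8, 0.4, 0.9],
                      [0.3, 0.6, 0.2, 0.7]])
labels = torch.tensor([[1, 0, 1, 1],
                       [0, 1, 0, 1],
                       [0, 1, 0, 0]])

metric = BalancedAccuracy(num_labels=4)   # default threshold=0.5
metric.update(preds, labels)              # accumulates per-class TP/FP/TN/FN counts
print(metric.compute())                   # mean over classes of (TPR + TNR) / 2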

chebai/cli.py

Lines changed: 2 additions & 2 deletions
@@ -11,10 +11,10 @@ def __init__(self, *args, **kwargs):

    def add_arguments_to_parser(self, parser: LightningArgumentParser):
        for kind in ("train", "val", "test"):
-           for average in ("micro", "macro"):
+           for average in ("micro-f1", "macro-f1", "balanced-accuracy"):
                parser.link_arguments(
                    "model.init_args.out_dim",
-                   f"model.init_args.{kind}_metrics.init_args.metrics.{average}-f1.init_args.num_labels",
+                   f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",
                )
        parser.link_arguments(
            "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
class_path: torchmetrics.MetricCollection
init_args:
  metrics:
    balanced-accuracy:
      class_path: chebai.callbacks.epoch_metrics.BalancedAccuracy
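
This config is roughly equivalent to constructing the collection by hand; a minimal sketch, assuming a placeholder num_labels (normally filled in from model.init_args.out_dim by the link_arguments calls in chebai/cli.py above):

import torchmetrics
from chebai.callbacks.epoch_metrics import BalancedAccuracy

num_labels = 1000  # placeholder; the CLI links this from model.init_args.out_dim

# the dict key must match the name targeted in cli.py ("balanced-accuracy"),
# i.e. ...metrics.balanced-accuracy.init_args.num_labels
metrics = torchmetrics.MetricCollection(
    {"balanced-accuracy": BalancedAccuracy(num_labels=num_labels)}
)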

setup.py

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@
        "wandb",
        "chardet",
        "yaml",
+       "torchmetrics",
    ],
    extras_require={"dev": ["black", "isort", "pre-commit"]},
)
Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
import unittest
import torch
import os
from chebai.callbacks.epoch_metrics import BalancedAccuracy
import random


class TestCustomMacroF1Metric(unittest.TestCase):

    @classmethod
    def setUpClass(cls) -> None:
        cls.device = "cuda" if torch.cuda.is_available() else "cpu"

    def test_iterative_vs_single_call_approach(self):
        """Test the custom metric implementation in the iterative (update-based) approach against
        the single-call approach"""

        preds = torch.tensor([[1, 1, 0, 1],
                              [1, 0, 1, 1],
                              [0, 1, 0, 1]])
        label = torch.tensor([[0, 0, 0, 0],
                              [0, 0, 1, 1],
                              [0, 1, 0, 1]])

        num_labels = label.shape[1]
        iterative_custom_metric = BalancedAccuracy(num_labels=num_labels)
        for i in range(label.shape[0]):
            iterative_custom_metric.update(preds[i].unsqueeze(0), label[i].unsqueeze(0))
        iterative_custom_metric_score = iterative_custom_metric.compute().item()

        single_call_custom_metric = BalancedAccuracy(num_labels=num_labels)
        single_call_custom_metric_score = single_call_custom_metric(preds, label).item()

        self.assertEqual(iterative_custom_metric_score, single_call_custom_metric_score)

    def test_metric_against_realistic_data(self):
        """Run the custom metric on realistic data"""
        directory_path = "CheBIOver100_test"
        abs_path = os.path.join(os.getcwd(), directory_path)
        print(f"Checking data from - {abs_path}")
        num_of_files = len(os.listdir(abs_path)) // 2

        # load a single file to get the number of labels for metric class instantiation
        labels = torch.load(f'{directory_path}/labels{0:03d}.pt', map_location=torch.device(self.device))
        num_labels = labels.shape[1]
        balanced_acc_custom = BalancedAccuracy(num_labels=num_labels)

        for i in range(num_of_files):
            labels = torch.load(f'{directory_path}/labels{i:03d}.pt', map_location=torch.device(self.device))
            preds = torch.load(f'{directory_path}/preds{i:03d}.pt', map_location=torch.device(self.device))
            balanced_acc_custom.update(preds, labels)

        balanced_acc_custom_score = balanced_acc_custom.compute().item()
        print(f"Balanced accuracy for realistic data: {balanced_acc_custom_score}")

    def test_case_when_few_class_has_no_labels(self):
        """Test the custom metric for the scenario where some classes have no positive labels"""
        preds = torch.tensor([[1, 1, 0, 1],
                              [1, 0, 1, 1],
                              [0, 1, 0, 1]])
        label = torch.tensor([[0, 0, 0, 0],  # no positive labels
                              [0, 0, 1, 1],
                              [0, 1, 0, 1]])

        # tp = [0, 1, 1, 2], fp = [2, 1, 0, 1], tn = [1, 1, 2, 0], fn = [0, 0, 0, 0]
        # tpr = [0, 1, 1, 2] / ([0, 1, 1, 2] + [0, 0, 0, 0]) = [0, 1, 1, 1]   (NaN -> 0)
        # tnr = [1, 1, 2, 0] / ([1, 1, 2, 0] + [2, 1, 0, 1]) = [0.33333, 0.5, 1, 0]
        # balanced accuracy = ([0, 1, 1, 1] + [0.33333, 0.5, 1, 0]) / 2 = [0.16666667, 0.75, 1, 0.5]
        # mean balanced accuracy = 0.6041666666666666

        balanced_acc_score = self.__get_custom_metric_score(preds, label, label.shape[1])

        self.assertAlmostEqual(balanced_acc_score, 0.6041666666, places=4)

    def test_all_predictions_are_1_half_labels_are_1(self):
        """Test the custom metric for the scenario where all predictions are 1 but only half of
        the labels are 1"""
        preds = torch.ones((1, 900), dtype=torch.int)
        label = torch.ones((1, 900), dtype=torch.int)

        mask = [[True] * (label.size(1) // 2) + [False] * (label.size(1) - (label.size(1) // 2))]
        random.shuffle(mask[0])
        label[torch.tensor(mask)] = 0

        # e.g. preds = [1, 1, 1, 1], label = [0, 1, 0, 1]
        # tp = [0, 1, 0, 1], fp = [1, 0, 1, 0], tn = [0, 0, 0, 0], fn = [0, 0, 0, 0]
        # tpr = tp / (tp + fn) = [0, 1, 0, 1] / [0, 1, 0, 1] = [0, 1, 0, 1]   (NaN -> 0)
        # tnr = tn / (tn + fp) = [0, 0, 0, 0]
        # mean balanced accuracy = mean([0, 0.5, 0, 0.5]) = 0.25

        balanced_acc_custom_score = self.__get_custom_metric_score(preds, label, label.shape[1])
        self.assertAlmostEqual(balanced_acc_custom_score, 0.25, places=4)

    def test_all_labels_are_1_half_predictions_are_1(self):
        """Test the custom metric for the scenario where all labels are 1 but only half of
        the predictions are 1"""
        preds = torch.ones((1, 900), dtype=torch.int)
        label = torch.ones((1, 900), dtype=torch.int)

        mask = [[True] * (label.size(1) // 2) + [False] * (label.size(1) - (label.size(1) // 2))]
        random.shuffle(mask[0])
        preds[torch.tensor(mask)] = 0

        # e.g. label = [1, 1, 1, 1], preds = [0, 1, 0, 1]
        # tp = [0, 1, 0, 1], fp = [0, 0, 0, 0], tn = [0, 0, 0, 0], fn = [1, 0, 1, 0]
        # tpr = tp / (tp + fn) = [0, 1, 0, 1] / [1, 1, 1, 1] = [0, 1, 0, 1]
        # tnr = tn / (tn + fp) = [0, 0, 0, 0]   (NaN -> 0)
        # mean balanced accuracy = mean([0, 0.5, 0, 0.5]) = 0.25

        balanced_acc_custom_score = self.__get_custom_metric_score(preds, label, label.shape[1])
        self.assertAlmostEqual(balanced_acc_custom_score, 0.25, places=4)

    @staticmethod
    def __get_custom_metric_score(preds, labels, num_labels):
        balanced_acc_custom = BalancedAccuracy(num_labels=num_labels)
        return balanced_acc_custom(preds, labels).item()


if __name__ == "__main__":
    unittest.main()

tests/testCustomMacroF1Metric.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
import unittest
import torch
import os
from chebai.callbacks.epoch_metrics import MacroF1
from torchmetrics.classification import MultilabelF1Score
import random


class TestCustomMacroF1Metric(unittest.TestCase):

    @classmethod
    def setUpClass(cls) -> None:
        cls.device = "cuda" if torch.cuda.is_available() else "cpu"

    @unittest.expectedFailure
    def test_all_predictions_are_1_half_labels_are_1(self):
        """Test the custom metric against the standard metric for the scenario where all
        predictions are 1 but only half of the labels are 1"""
        preds = torch.ones((1, 900), dtype=torch.int)
        label = torch.ones((1, 900), dtype=torch.int)

        mask = [
            [True] * (label.size(1) // 2)
            + [False] * (label.size(1) - (label.size(1) // 2))
        ]
        random.shuffle(mask[0])
        label[torch.tensor(mask)] = 0

        macro_f1_custom_score, macro_f1_standard_score = (
            self.__get_custom_and_standard_metric_scores(label.shape[1], preds, label)
        )

        # e.g. preds = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
        #      label = torch.tensor([[1, 1, 0, 0, 1, 1, 0, 0, 1, 0]])
        # tps = [1, 1, 0, 0, 1, 1, 0, 0, 1, 0]
        # positive_predictions = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
        # positive_labels = [1, 1, 0, 0, 1, 1, 0, 0, 1, 0]

        # ---------------------- For the standard macro-F1 metric ---------------------
        # The metric is only properly defined when TP + FP ≠ 0 ∧ TP + FN ≠ 0.
        # If this case is encountered for any class/label, the metric for that class/label
        # is set to 0, and the overall metric may therefore be affected in turn.

        # precision = [1, 1, 0, 0, 1, 1, 0, 0, 1, 0]
        # recall = [1, 1, 0, 0, 1, 1, 0, 0, 1, 0]
        # classwise_f1 = [2, 2, 0, 0, 2, 2, 0, 0, 2, 0] / [2, 2, 0, 0, 2, 2, 0, 0, 2, 0]
        #              = [1, 1, 0, 0, 1, 1, 0, 0, 1, 0]
        # mean = 5/10 = 0.5

        # ----------------------- For the custom F1 metric ----------------------------
        # Masking is performed as a first step to keep only classes with positive labels
        # mask = [True, True, False, False, True, True, False, False, True, False]
        # precision = [1, 1, 1, 1, 1] / [1, 1, 1, 1, 1] = [1, 1, 1, 1, 1]
        # recall = [1, 1, 1, 1, 1] / [1, 1, 1, 1, 1] = [1, 1, 1, 1, 1]
        # classwise_f1 = [2, 2, 2, 2, 2] / [2, 2, 2, 2, 2] = [1, 1, 1, 1, 1]
        # mean = 5/5 = 1 (because of the masking we average across positive labels only)

        self.assertAlmostEqual(macro_f1_custom_score, macro_f1_standard_score, places=4)

    def test_all_labels_are_1_half_predictions_are_1(self):
        """Test the custom metric against the standard metric for the scenario where all labels
        are 1 but only half of the predictions are 1"""
        preds = torch.ones((1, 900), dtype=torch.int)
        label = torch.ones((1, 900), dtype=torch.int)

        mask = [
            [True] * (label.size(1) // 2)
            + [False] * (label.size(1) - (label.size(1) // 2))
        ]
        random.shuffle(mask[0])
        preds[torch.tensor(mask)] = 0

        macro_f1_custom_score, macro_f1_standard_score = (
            self.__get_custom_and_standard_metric_scores(label.shape[1], preds, label)
        )

        # As the custom metric only considers classes with positive labels (via masking),
        # and all labels are positive in this scenario, the custom and standard metrics agree.
        self.assertAlmostEqual(macro_f1_custom_score, macro_f1_standard_score, places=4)

    def test_iterative_vs_single_call_approach(self):
        """Test the custom metric implementation in the iterative (update-based) approach against
        the single-call approach"""
        preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
        label = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]])

        num_labels = label.shape[1]
        iterative_custom_metric = MacroF1(num_labels=num_labels)
        for i in range(label.shape[0]):
            iterative_custom_metric.update(preds[i].unsqueeze(0), label[i].unsqueeze(0))
        iterative_custom_metric_score = iterative_custom_metric.compute().item()

        single_call_custom_metric = MacroF1(num_labels=num_labels)
        single_call_custom_metric_score = single_call_custom_metric(preds, label).item()

        self.assertEqual(iterative_custom_metric_score, single_call_custom_metric_score)

    @unittest.expectedFailure
    def test_metric_against_realistic_data(self):
        """Test the custom metric against the standard metric on realistic data"""
        directory_path = "CheBIOver100_test"
        abs_path = os.path.join(os.getcwd(), directory_path)
        print(f"Checking data from - {abs_path}")
        num_of_files = len(os.listdir(abs_path)) // 2

        # load a single file to get the number of labels for metric class instantiation
        labels = torch.load(
            f"{directory_path}/labels{0:03d}.pt", map_location=torch.device(self.device)
        )
        num_labels = labels.shape[1]
        macro_f1_custom = MacroF1(num_labels=num_labels)
        macro_f1_standard = MultilabelF1Score(num_labels=num_labels, average="macro")

        # load each file in the directory and update the stats
        for i in range(num_of_files):
            labels = torch.load(
                f"{directory_path}/labels{i:03d}.pt",
                map_location=torch.device(self.device),
            )
            preds = torch.load(
                f"{directory_path}/preds{i:03d}.pt",
                map_location=torch.device(self.device),
            )
            macro_f1_standard.update(preds, labels)
            macro_f1_custom.update(preds, labels)

        macro_f1_custom_score = macro_f1_custom.compute().item()
        macro_f1_standard_score = macro_f1_standard.compute().item()
        print(
            f"Realistic Data - Custom F1 score: {macro_f1_custom_score}, Std. F1 score: {macro_f1_standard_score}"
        )

        self.assertAlmostEqual(macro_f1_custom_score, macro_f1_standard_score, places=4)

    @unittest.expectedFailure
    def test_case_when_few_class_has_no_labels(self):
        """Test the custom metric against the standard metric for the scenario where some classes
        have no positive labels"""
        preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
        label = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]])
        macro_f1_custom_score, macro_f1_standard_score = (
            self.__get_custom_and_standard_metric_scores(label.shape[1], preds, label)
        )

        self.assertAlmostEqual(macro_f1_custom_score, macro_f1_standard_score, places=4)

    @staticmethod
    def __get_custom_and_standard_metric_scores(num_labels, preds, labels):
        # Custom metric score
        macro_f1_custom = MacroF1(num_labels=num_labels)
        macro_f1_custom_score = macro_f1_custom(preds, labels).item()

        # Standard metric score
        macro_f1_standard = MultilabelF1Score(num_labels=num_labels, average="macro")
        macro_f1_standard_score = macro_f1_standard(preds, labels).item()

        return macro_f1_custom_score, macro_f1_standard_score


if __name__ == "__main__":
    unittest.main()
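
The expectedFailure cases above hinge on how the standard metric scores classes without positive labels (an F1 of 0 for such classes, as worked out in the comments). A minimal sketch reproducing the 0.5 figure from that derivation; the toy tensors are copied from the comment and are not part of the commit:

import torch
from torchmetrics.classification import MultilabelF1Score

preds = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
label = torch.tensor([[1, 1, 0, 0, 1, 1, 0, 0, 1, 0]])

# classes with no positive labels contribute an F1 of 0, pulling the macro
# average down to 0.5; per the comments, the custom MacroF1 masks them out
# and yields 1.0 for the same inputs
standard = MultilabelF1Score(num_labels=10, average="macro")
print(standard(preds, label))  # ~0.5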
