Commit 3ca5707

Merge pull request #27 from ChEB-AI/feature/testing_framework
Metric Testing + Implementation
2 parents: 648c675 + 3ce690c

Some content is hidden: large commits have some content hidden by default.
46 files changed: +400 −2 lines changed

chebai/callbacks/epoch_metrics.py

Lines changed: 74 additions & 0 deletions
@@ -47,3 +47,77 @@ def compute(self):
         # if (precision and recall are 0) or (precision is nan), set f1 to 0
         classwise_f1 = classwise_f1.nan_to_num()
         return torch.mean(classwise_f1)
+
+
+class BalancedAccuracy(torchmetrics.Metric):
+    """Balanced Accuracy = (TPR + TNR) / 2 = (TP / (TP + FN) + TN / (TN + FP)) / 2
+
+    This metric computes the balanced accuracy, which is the average of the true positive rate (TPR)
+    and the true negative rate (TNR). It is useful for imbalanced datasets where the classes are not
+    represented equally.
+    """
+
+    def __init__(self, num_labels, dist_sync_on_step=False, threshold=0.5):
+        super().__init__(dist_sync_on_step=dist_sync_on_step)
+
+        self.add_state(
+            "true_positives",
+            default=torch.zeros(num_labels, dtype=torch.int),
+            dist_reduce_fx="sum",
+        )
+
+        self.add_state(
+            "false_positives",
+            default=torch.zeros(num_labels, dtype=torch.int),
+            dist_reduce_fx="sum",
+        )
+
+        self.add_state(
+            "true_negatives",
+            default=torch.zeros(num_labels, dtype=torch.int),
+            dist_reduce_fx="sum",
+        )
+
+        self.add_state(
+            "false_negatives",
+            default=torch.zeros(num_labels, dtype=torch.int),
+            dist_reduce_fx="sum",
+        )
+
+        self.threshold = threshold
+
+    def update(self, preds: torch.Tensor, labels: torch.Tensor):
+        """Update the TPs, TNs, FPs and FNs."""
+
+        # preds and labels have size batch_size x num_of_classes;
+        # summing over the batch dimension (dim=0) gives the per-class counts
+        tps = torch.sum(
+            torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0
+        )
+        fps = torch.sum(
+            torch.logical_and(preds > self.threshold, ~labels.to(torch.bool)), dim=0
+        )
+        tns = torch.sum(
+            torch.logical_and(preds <= self.threshold, ~labels.to(torch.bool)), dim=0
+        )
+        fns = torch.sum(
+            torch.logical_and(preds <= self.threshold, labels.to(torch.bool)), dim=0
+        )
+
+        # size: num_of_classes
+        self.true_positives += tps
+        self.false_positives += fps
+        self.true_negatives += tns
+        self.false_negatives += fns
+
+    def compute(self):
+        """Compute the balanced accuracy from the accumulated per-class counts."""
+
+        tpr = self.true_positives / (self.true_positives + self.false_negatives)
+        tnr = self.true_negatives / (self.true_negatives + self.false_positives)
+        # convert NaN values (from classes with a zero denominator) to 0
+        tpr = tpr.nan_to_num()
+        tnr = tnr.nan_to_num()
+
+        balanced_acc = (tpr + tnr) / 2
+        return torch.mean(balanced_acc)
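
As a quick illustration of how the new metric behaves, here is a minimal usage sketch; the toy tensors and the expected value are illustrative only, not taken from the repository:

    import torch
    from chebai.callbacks.epoch_metrics import BalancedAccuracy

    # hypothetical toy batch: 2 samples x 3 classes, probabilities and binary labels
    preds = torch.tensor([[0.9, 0.2, 0.7], [0.4, 0.8, 0.6]])
    labels = torch.tensor([[1, 0, 1], [0, 1, 0]])

    metric = BalancedAccuracy(num_labels=3)  # default threshold=0.5
    metric.update(preds, labels)
    # classes 0 and 1 are classified perfectly (balanced accuracy 1.0 each);
    # class 2 has tpr=1 and tnr=0, i.e. 0.5; mean = (1 + 1 + 0.5) / 3
    print(metric.compute())  # expected: tensor(0.8333)

Because the class subclasses torchmetrics.Metric, calling metric(preds, labels) directly (the single-call form exercised in the tests below) updates the state and returns the value for that batch in one step.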

chebai/cli.py

Lines changed: 2 additions & 2 deletions
@@ -11,10 +11,10 @@ def __init__(self, *args, **kwargs):
 
     def add_arguments_to_parser(self, parser: LightningArgumentParser):
         for kind in ("train", "val", "test"):
-            for average in ("micro", "macro"):
+            for average in ("micro-f1", "macro-f1", "balanced-accuracy"):
                 parser.link_arguments(
                     "model.init_args.out_dim",
-                    f"model.init_args.{kind}_metrics.init_args.metrics.{average}-f1.init_args.num_labels",
+                    f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",
                 )
         parser.link_arguments(
             "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+class_path: torchmetrics.MetricCollection
+init_args:
+  metrics:
+    balanced-accuracy:
+      class_path: chebai.callbacks.epoch_metrics.BalancedAccuracy
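
This config registers the metric in a torchmetrics.MetricCollection under the key balanced-accuracy, which is exactly the key the CLI linking above targets. Programmatically it corresponds roughly to the following sketch (num_labels is normally injected via the link from model.init_args.out_dim; the value below is a placeholder):

    from torchmetrics import MetricCollection
    from chebai.callbacks.epoch_metrics import BalancedAccuracy

    # num_labels would be filled in from model.init_args.out_dim by the CLI;
    # 1000 is a placeholder
    metrics = MetricCollection({"balanced-accuracy": BalancedAccuracy(num_labels=1000)})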

setup.py

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@
         "wandb",
         "chardet",
         "yaml",
+        "torchmetrics",
     ],
     extras_require={"dev": ["black", "isort", "pre-commit"]},
 )
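
Declaring torchmetrics explicitly in install_requires matches its direct use in chebai/callbacks/epoch_metrics.py (BalancedAccuracy subclasses torchmetrics.Metric), rather than relying on the package being present transitively.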
Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
+import unittest
+import torch
+import os
+from chebai.callbacks.epoch_metrics import BalancedAccuracy
+import random
+
+
+class TestCustomBalancedAccuracyMetric(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    def test_iterative_vs_single_call_approach(self):
+        """Test the custom metric implementation in the iterative, update-based approach
+        against the single-call approach."""
+
+        preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
+        label = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]])
+
+        num_labels = label.shape[1]
+        iterative_custom_metric = BalancedAccuracy(num_labels=num_labels)
+        for i in range(label.shape[0]):
+            iterative_custom_metric.update(preds[i].unsqueeze(0), label[i].unsqueeze(0))
+        iterative_custom_metric_score = iterative_custom_metric.compute().item()
+
+        single_call_custom_metric = BalancedAccuracy(num_labels=num_labels)
+        single_call_custom_metric_score = single_call_custom_metric(preds, label).item()
+
+        self.assertEqual(iterative_custom_metric_score, single_call_custom_metric_score)
+
+    def test_metric_against_realistic_data(self):
+        """Test the custom metric on realistic data."""
+        directory_path = os.path.join("tests", "test_data", "CheBIOver100_test")
+        abs_path = os.path.join(os.getcwd(), directory_path)
+        print(f"Checking data from - {abs_path}")
+        num_of_files = len(os.listdir(abs_path)) // 2
+
+        # load a single file to get the number of labels for metric class instantiation
+        labels = torch.load(
+            f"{directory_path}/labels{0:03d}.pt", map_location=torch.device(self.device)
+        )
+        num_labels = labels.shape[1]
+        balanced_acc_custom = BalancedAccuracy(num_labels=num_labels)
+
+        for i in range(num_of_files):
+            labels = torch.load(
+                f"{directory_path}/labels{i:03d}.pt",
+                map_location=torch.device(self.device),
+            )
+            preds = torch.load(
+                f"{directory_path}/preds{i:03d}.pt",
+                map_location=torch.device(self.device),
+            )
+            balanced_acc_custom.update(preds, labels)
+
+        balanced_acc_custom_score = balanced_acc_custom.compute().item()
+        print(f"Balanced Accuracy for realistic data: {balanced_acc_custom_score}")
+
+    def test_case_when_few_class_has_no_labels(self):
+        """Test the custom metric for the scenario where a class has no positive labels."""
+        preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
+        label = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]])  # class 0 has no positive labels
+
+        # tp = [0, 1, 1, 2], fp = [2, 1, 0, 1], tn = [1, 1, 2, 0], fn = [0, 0, 0, 0]
+        # tpr = [0, 1, 1, 2] / ([0, 1, 1, 2] + [0, 0, 0, 0]) = [0, 1, 1, 1]  (0/0 -> 0)
+        # tnr = [1, 1, 2, 0] / ([1, 1, 2, 0] + [2, 1, 0, 1]) = [0.33333, 0.5, 1, 0]
+        # balanced_accuracy = ([0, 1, 1, 1] + [0.33333, 0.5, 1, 0]) / 2 = [0.16666667, 0.75, 1, 0.5]
+        # mean balanced accuracy = 0.6041666666666666
+
+        balanced_acc_score = self.__get_custom_metric_score(
+            preds, label, label.shape[1]
+        )
+
+        self.assertAlmostEqual(balanced_acc_score, 0.6041666666, places=4)
+
+    def test_all_predictions_are_1_half_labels_are_1(self):
+        """Test the custom metric for the scenario where all predictions are 1 but only half of
+        the labels are 1."""
+        preds = torch.ones((1, 900), dtype=torch.int)
+        label = torch.ones((1, 900), dtype=torch.int)
+
+        mask = [
+            [True] * (label.size(1) // 2)
+            + [False] * (label.size(1) - (label.size(1) // 2))
+        ]
+        random.shuffle(mask[0])
+        label[torch.tensor(mask)] = 0
+
+        # e.g. preds = [1, 1, 1, 1], label = [0, 1, 0, 1]
+        # tp = [0, 1, 0, 1], fp = [1, 0, 1, 0], tn = [0, 0, 0, 0], fn = [0, 0, 0, 0]
+        # tpr = tp / (tp + fn) = [0, 1, 0, 1] / [0, 1, 0, 1] = [0, 1, 0, 1]  (0/0 -> 0)
+        # tnr = tn / (tn + fp) = [0, 0, 0, 0]
+        # balanced accuracy = ([0, 1, 0, 1] + [0, 0, 0, 0]) / 2 = [0, 0.5, 0, 0.5]; mean = 0.25
+
+        balanced_acc_custom_score = self.__get_custom_metric_score(
+            preds, label, label.shape[1]
+        )
+        self.assertAlmostEqual(balanced_acc_custom_score, 0.25, places=4)
+
+    def test_all_labels_are_1_half_predictions_are_1(self):
+        """Test the custom metric for the scenario where all labels are 1 but only half of
+        the predictions are 1."""
+        preds = torch.ones((1, 900), dtype=torch.int)
+        label = torch.ones((1, 900), dtype=torch.int)
+
+        mask = [
+            [True] * (label.size(1) // 2)
+            + [False] * (label.size(1) - (label.size(1) // 2))
+        ]
+        random.shuffle(mask[0])
+        preds[torch.tensor(mask)] = 0
+
+        # e.g. label = [1, 1, 1, 1], preds = [0, 1, 0, 1]
+        # tp = [0, 1, 0, 1], fp = [0, 0, 0, 0], tn = [0, 0, 0, 0], fn = [1, 0, 1, 0]
+        # tpr = tp / (tp + fn) = [0, 1, 0, 1] / [1, 1, 1, 1] = [0, 1, 0, 1]
+        # tnr = tn / (tn + fp) = [0, 0, 0, 0]  (0/0 -> 0)
+        # balanced accuracy = ([0, 1, 0, 1] + [0, 0, 0, 0]) / 2 = [0, 0.5, 0, 0.5]; mean = 0.25
+
+        balanced_acc_custom_score = self.__get_custom_metric_score(
+            preds, label, label.shape[1]
+        )
+        self.assertAlmostEqual(balanced_acc_custom_score, 0.25, places=4)
+
+    @staticmethod
+    def __get_custom_metric_score(preds, labels, num_labels):
+        balanced_acc_custom = BalancedAccuracy(num_labels=num_labels)
+        return balanced_acc_custom(preds, labels).item()
+
+
+if __name__ == "__main__":
+    unittest.main()
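
As a sanity check on the hand-computed arithmetic in test_case_when_few_class_has_no_labels, the same tensors can be fed through the metric standalone (a sketch; the expected value is the one asserted in the test):

    import torch
    from chebai.callbacks.epoch_metrics import BalancedAccuracy

    preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
    labels = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]])

    metric = BalancedAccuracy(num_labels=4)
    print(metric(preds, labels).item())  # ~0.6041666, as asserted in the test

Assuming the new test module lives under tests/ with a standard test_*.py filename, the suite should be runnable from the repository root with python -m unittest discover; note that test_metric_against_realistic_data additionally expects prediction/label tensor files under tests/test_data/CheBIOver100_test.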
