Commit 2ec99a3

Tests (except ChEBI data): docstrings + type hints

- Updated all tests with docstrings and type hints, except tests related to ChEBI data, which depend on PR#29 / Issue#10; the ChEBI tests will be updated after that PR is merged.

1 parent 0de5e41 commit 2ec99a3

File tree

4 files changed (+224, -104 lines)

tests/testCustomBalancedAccuracyMetric.py

Lines changed: 43 additions & 15 deletions
@@ -6,15 +6,22 @@
 
 
 class TestCustomBalancedAccuracyMetric(unittest.TestCase):
+    """
+    Unit tests for the Custom Balanced Accuracy metric.
+    """
 
     @classmethod
     def setUpClass(cls) -> None:
+        """
+        Set up class-level variables.
+        """
         cls.device = "cuda" if torch.cuda.is_available() else "cpu"
 
-    def test_iterative_vs_single_call_approach(self):
-        """Test the custom metric implementation in update fashion approach against
-        the single call approach"""
-
+    def test_iterative_vs_single_call_approach(self) -> None:
+        """
+        Test the custom metric implementation in update fashion approach against
+        the single call approach.
+        """
         preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
         label = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]])
 
@@ -29,8 +36,10 @@ def test_iterative_vs_single_call_approach(self):
 
         self.assertEqual(iterative_custom_metric_score, single_call_custom_metric_score)
 
-    def test_metric_against_realistic_data(self):
-        """Test the custom metric against the standard on realistic data"""
+    def test_metric_against_realistic_data(self) -> None:
+        """
+        Test the custom metric against the standard on realistic data.
+        """
         directory_path = os.path.join("tests", "test_data", "CheBIOver100_test")
         abs_path = os.path.join(os.getcwd(), directory_path)
         print(f"Checking data from - {abs_path}")
@@ -57,8 +66,10 @@ def test_metric_against_realistic_data(self):
         balanced_acc_custom_score = balanced_acc_custom.compute().item()
         print(f"Balanced Accuracy for realistic data: {balanced_acc_custom_score}")
 
-    def test_case_when_few_class_has_no_labels(self):
-        """Test custom metric against standard metric for the scenario where some class has no labels"""
+    def test_case_when_few_class_has_no_labels(self) -> None:
+        """
+        Test custom metric against standard metric for the scenario where some class has no labels.
+        """
         preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
         label = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]])  # no labels
 
@@ -74,9 +85,11 @@ def test_case_when_few_class_has_no_labels(self):
 
         self.assertAlmostEqual(balanced_acc_score, 0.6041666666, places=4)
 
-    def test_all_predictions_are_1_half_labels_are_1(self):
-        """Test custom metric against standard metric for the scenario where all prediction are 1 but only half of
-        the labels are 1"""
+    def test_all_predictions_are_1_half_labels_are_1(self) -> None:
+        """
+        Test custom metric against standard metric for the scenario where all predictions are 1 but only half of
+        the labels are 1.
+        """
         preds = torch.ones((1, 900), dtype=torch.int)
         label = torch.ones((1, 900), dtype=torch.int)
 
@@ -98,9 +111,11 @@ def test_all_predictions_are_1_half_labels_are_1(self):
         )
         self.assertAlmostEqual(balanced_acc_custom_score, 0.25, places=4)
 
-    def test_all_labels_are_1_half_predictions_are_1(self):
-        """Test custom metric against standard metric for the scenario where all labels are 1 but only half of
-        the predictions are 1"""
+    def test_all_labels_are_1_half_predictions_are_1(self) -> None:
+        """
+        Test custom metric against standard metric for the scenario where all labels are 1 but only half of
+        the predictions are 1.
+        """
         preds = torch.ones((1, 900), dtype=torch.int)
         label = torch.ones((1, 900), dtype=torch.int)
 
@@ -123,7 +138,20 @@ def test_all_labels_are_1_half_predictions_are_1(self):
         self.assertAlmostEqual(balanced_acc_custom_score, 0.25, places=4)
 
     @staticmethod
-    def __get_custom_metric_score(preds, labels, num_labels):
+    def __get_custom_metric_score(
+        preds: torch.Tensor, labels: torch.Tensor, num_labels: int
+    ) -> float:
+        """
+        Helper function to compute the custom metric score.
+
+        Args:
+            - preds (torch.Tensor): Predictions tensor.
+            - labels (torch.Tensor): Labels tensor.
+            - num_labels (int): Number of labels/classes.
+
+        Returns:
+            - float: Computed custom metric score.
+        """
         balanced_acc_custom = BalancedAccuracy(num_labels=num_labels)
         return balanced_acc_custom(preds, labels).item()
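For reference, the 0.6041666666 value asserted in test_case_when_few_class_has_no_labels can be reproduced with a short plain-torch sketch of per-class balanced accuracy, (TPR + TNR) / 2 averaged over classes. This is an assumption about what BalancedAccuracy computes, including the zero-denominator guard for classes with no positive labels; the metric class itself is authoritative.

import torch

preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]]).float()
labels = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]]).float()

# Per-class confusion counts, summed over the batch dimension
tp = (preds * labels).sum(dim=0)
tn = ((1 - preds) * (1 - labels)).sum(dim=0)
fp = (preds * (1 - labels)).sum(dim=0)
fn = ((1 - preds) * labels).sum(dim=0)

# clamp() guards the 0/0 case for classes with no positives (class 0 here)
tpr = tp / (tp + fn).clamp(min=1e-12)  # sensitivity per class
tnr = tn / (tn + fp).clamp(min=1e-12)  # specificity per class

balanced_acc = ((tpr + tnr) / 2).mean()
print(balanced_acc.item())  # 0.60416..., matching the asserted 0.6041666666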

tests/testCustomMacroF1Metric.py

Lines changed: 50 additions & 18 deletions
@@ -7,24 +7,27 @@
 
 
 class TestCustomMacroF1Metric(unittest.TestCase):
-
     @classmethod
     def setUpClass(cls) -> None:
         cls.device = "cuda" if torch.cuda.is_available() else "cpu"
 
-    def test_all_predictions_are_1_half_labels_are_1(self):
-        """Test custom metric against standard metric for the scenario where all prediction are 1 but only half of
-        the labels are 1"""
+    def test_all_predictions_are_1_half_labels_are_1(self) -> None:
+        """
+        Test custom metric against standard metric for the scenario where all predictions are 1
+        but only half of the labels are 1.
+        """
         preds = torch.ones((1, 900), dtype=torch.int)
         label = torch.ones((1, 900), dtype=torch.int)
 
+        # Randomly set half of the labels to 0
         mask = [
             [True] * (label.size(1) // 2)
             + [False] * (label.size(1) - (label.size(1) // 2))
         ]
         random.shuffle(mask[0])
         label[torch.tensor(mask)] = 0
 
+        # Get custom and standard metric scores
         macro_f1_custom_score, macro_f1_standard_score = (
             self.__get_custom_and_standard_metric_scores(label.shape[1], preds, label)
         )
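The masking idiom in this hunk is worth noting on its own: build a Python list that is True for exactly half the positions, shuffle it in place, and use it as a boolean index. A minimal standalone sketch:

import random
import torch

label = torch.ones((1, 10), dtype=torch.int)

mask = [
    [True] * (label.size(1) // 2)
    + [False] * (label.size(1) - (label.size(1) // 2))
]
random.shuffle(mask[0])  # shuffle the inner list in place
label[torch.tensor(mask)] = 0  # boolean indexing zeroes the masked entries

assert int(label.sum()) == label.size(1) // 2  # exactly half remain 1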
@@ -52,25 +55,29 @@ def test_all_predictions_are_1_half_labels_are_1(self):
         # precision = [1, 1, 1, 1, 1] / [1, 1, 1, 1, 1] = [1, 1, 1, 1, 1]
         # recall = [1, 1, 1, 1, 1] / [1, 1, 1, 1, 1] = [1, 1, 1, 1, 1]
         # classwise_f1 = [2, 2, 2, 2, 2] / [2, 2, 2, 2, 2] = [1, 1, 1, 1, 1]
-        # mean = 5/5 = 1 (because of masking we averaging with across positive labels only)
+        # mean = 5/5 = 1 (because of masking we're averaging across positive labels only)
         self.assertAlmostEqual(macro_f1_custom_score, 1, places=4)
         self.assertNotAlmostEqual(
             macro_f1_custom_score, macro_f1_standard_score, places=4
         )
 
-    def test_all_labels_are_1_half_predictions_are_1(self):
-        """Test custom metric against standard metric for the scenario where all labels are 1 but only half of
-        the predictions are 1"""
+    def test_all_labels_are_1_half_predictions_are_1(self) -> None:
+        """
+        Test custom metric against standard metric for the scenario where all labels are 1
+        but only half of the predictions are 1.
+        """
         preds = torch.ones((1, 900), dtype=torch.int)
         label = torch.ones((1, 900), dtype=torch.int)
 
+        # Randomly set half of the predictions to 0
         mask = [
             [True] * (label.size(1) // 2)
             + [False] * (label.size(1) - (label.size(1) // 2))
         ]
         random.shuffle(mask[0])
         preds[torch.tensor(mask)] = 0
 
+        # Get custom and standard metric scores
         macro_f1_custom_score, macro_f1_standard_score = (
             self.__get_custom_and_standard_metric_scores(label.shape[1], preds, label)
         )
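The "averaging across positive labels only" behaviour described in the comment above can be sketched in plain torch: compute per-class F1 as usual, then take the mean only over classes that have at least one positive label. This is a sketch of the presumed masking semantics, using the small tensors from the other tests, not the repo's MacroF1 implementation itself.

import torch

preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]]).float()
labels = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]]).float()

tp = (preds * labels).sum(dim=0)
fp = (preds * (1 - labels)).sum(dim=0)
fn = ((1 - preds) * labels).sum(dim=0)

precision = tp / (tp + fp).clamp(min=1e-12)
recall = tp / (tp + fn).clamp(min=1e-12)
classwise_f1 = 2 * precision * recall / (precision + recall).clamp(min=1e-12)

# Mask: average only over classes with at least one positive label
positive_classes = labels.sum(dim=0) > 0
macro_f1 = classwise_f1[positive_classes].mean()
print(macro_f1.item())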
@@ -79,9 +86,11 @@ def test_all_labels_are_1_half_predictions_are_1(self):
         # and since all labels are positive in this scenario, custom and std metric are same
         self.assertAlmostEqual(macro_f1_custom_score, macro_f1_standard_score, places=4)
 
-    def test_iterative_vs_single_call_approach(self):
-        """Test the custom metric implementation in update fashion approach against
-        the single call approach"""
+    def test_iterative_vs_single_call_approach(self) -> None:
+        """
+        Test the custom metric implementation in update fashion approach against
+        the single call approach.
+        """
         preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
         label = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]])
 
@@ -94,24 +103,27 @@ def test_iterative_vs_single_call_approach(self):
         single_call_custom_metric = MacroF1(num_labels=num_labels)
         single_call_custom_metric_score = single_call_custom_metric(preds, label).item()
 
+        # Assert iterative and single call approaches give the same metric score
         self.assertEqual(iterative_custom_metric_score, single_call_custom_metric_score)
 
-    def test_metric_against_realistic_data(self):
-        """Test the custom metric against the standard on realistic data"""
+    def test_metric_against_realistic_data(self) -> None:
+        """
+        Test the custom metric against the standard on realistic data.
+        """
         directory_path = os.path.join("tests", "test_data", "CheBIOver100_test")
         abs_path = os.path.join(os.getcwd(), directory_path)
         print(f"Checking data from - {abs_path}")
         num_of_files = len(os.listdir(abs_path)) // 2
 
-        # load single file to get the num of labels for metric class instantiation
+        # Load single file to get the number of labels for metric class instantiation
         labels = torch.load(
             f"{directory_path}/labels{0:03d}.pt", map_location=torch.device(self.device)
         )
         num_labels = labels.shape[1]
         macro_f1_custom = MacroF1(num_labels=num_labels)
         macro_f1_standard = MultilabelF1Score(num_labels=num_labels, average="macro")
 
-        # load each file in the directory and update the stats
+        # Load each file in the directory and update the metrics
         for i in range(num_of_files):
             labels = torch.load(
                 f"{directory_path}/labels{i:03d}.pt",
@@ -130,14 +142,19 @@ def test_metric_against_realistic_data(self):
             f"Realistic Data - Custom F1 score: {macro_f1_custom_score}, Std. F1 score: {macro_f1_standard_score}"
         )
 
+        # Assert custom metric score is not equal to standard metric score
         self.assertNotAlmostEqual(
             macro_f1_custom_score, macro_f1_standard_score, places=4
         )
 
-    def test_case_when_few_class_has_no_labels(self):
-        """Test custom metric against standard metric for the scenario where some class has no labels"""
+    def test_case_when_few_class_has_no_labels(self) -> None:
+        """
+        Test custom metric against standard metric for the scenario where some class has no labels.
+        """
         preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
         label = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]])
+
+        # Get custom and standard metric scores
         macro_f1_custom_score, macro_f1_standard_score = (
             self.__get_custom_and_standard_metric_scores(label.shape[1], preds, label)
         )
@@ -170,7 +187,22 @@ def test_case_when_few_class_has_no_labels(self):
         )
 
     @staticmethod
-    def __get_custom_and_standard_metric_scores(num_labels, preds, labels):
+    def __get_custom_and_standard_metric_scores(
+        num_labels: int, preds: torch.Tensor, labels: torch.Tensor
+    ) -> tuple:
+        """
+        Helper method to calculate custom and standard macro F1 scores.
+
+        Args:
+            num_labels (int): Number of labels/classes.
+            preds (torch.Tensor): Predicted tensor of shape (batch_size, num_labels).
+            labels (torch.Tensor): True labels tensor of shape (batch_size, num_labels).
+
+        Returns:
+            tuple: A tuple containing two floats:
+                - macro_f1_custom_score: Custom macro F1 score.
+                - macro_f1_standard_score: Standard macro F1 score.
+        """
         # Custom metric score
         macro_f1_custom = MacroF1(num_labels=num_labels)
         macro_f1_custom_score = macro_f1_custom(preds, labels).item()
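The iterative-versus-single-call equivalence checked by test_iterative_vs_single_call_approach holds because these metrics accumulate confusion-matrix state in update() and only aggregate in compute(). A minimal sketch of the two call patterns, using torchmetrics' MultilabelF1Score as a stand-in for the repo's MacroF1:

import torch
from torchmetrics.classification import MultilabelF1Score

preds = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1]])
label = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1], [0, 1, 0, 1]])
num_labels = label.shape[1]

# Iterative: feed one row at a time, then aggregate once at the end
iterative = MultilabelF1Score(num_labels=num_labels, average="macro")
for row_preds, row_label in zip(preds, label):
    iterative.update(row_preds.unsqueeze(0), row_label.unsqueeze(0))

# Single call: one forward pass over the whole batch
single = MultilabelF1Score(num_labels=num_labels, average="macro")

# Both paths accumulate the same tp/fp/fn counts, so the scores agree
assert iterative.compute().item() == single(preds, label).item()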
