Skip to content

Commit 4d37e3b

Browse files
committed
chebi tests : docstring + typehints
1 parent 69d5636 commit 4d37e3b

File tree

2 files changed

+144
-48
lines changed

2 files changed

+144
-48
lines changed

tests/testChebiData.py

Lines changed: 70 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,36 @@
11
import unittest
2+
from typing import List
3+
4+
import pandas as pd
25

36
from chebai.preprocessing.datasets.chebi import ChEBIOver50
47

58

69
class TestChebiData(unittest.TestCase):
10+
"""
11+
Test case for ChEBIOver50 dataset integrity, focusing on data splits and overlaps.
12+
13+
Attributes:
14+
overlaps_train_val (List): List of overlapping entities between train and validation splits based on SMILES.
15+
overlaps_train_test (List): List of overlapping entities between train and test splits based on SMILES.
16+
overlaps_val_test (List): List of overlapping entities between validation and test splits based on SMILES.
17+
overlaps_train_val_ids (List): List of overlapping entity IDs between train and validation splits.
18+
overlaps_train_test_ids (List): List of overlapping entity IDs between train and test splits.
19+
overlaps_val_test_ids (List): List of overlapping entity IDs between validation and test splits.
20+
"""
721

822
@classmethod
923
def setUpClass(cls) -> None:
24+
"""
25+
Set up class method to initialize ChEBIOver50 instance and generate data splits and overlaps.
26+
"""
1027
cls.getDataSplitsOverlaps()
1128

1229
@classmethod
13-
def getDataSplitsOverlaps(cls):
14-
"""Get the overlap between data splits"""
30+
def getDataSplitsOverlaps(cls) -> None:
31+
"""
32+
Get the overlap between data splits based on SMILES and IDs.
33+
"""
1534
chebi_class_obj = ChEBIOver50()
1635
# Get the raw/processed data if missing
1736
chebi_class_obj.prepare_data()
@@ -25,37 +44,56 @@ def getDataSplitsOverlaps(cls):
2544
val_smiles, val_smiles_ids = cls.get_features_ids(val_set)
2645
test_smiles, test_smiles_ids = cls.get_features_ids(test_set)
2746

28-
# ----- Get the overlap between data splits based on smiles tokens/features -----
47+
# Get the overlap between data splits based on smiles tokens/features
2948
cls.overlaps_train_val = cls.get_overlaps(train_smiles, val_smiles)
3049
cls.overlaps_train_test = cls.get_overlaps(train_smiles, test_smiles)
3150
cls.overlaps_val_test = cls.get_overlaps(val_smiles, test_smiles)
3251

33-
# ----- Get the overlap between data splits based on IDs -----
52+
# Get the overlap between data splits based on IDs
3453
cls.overlaps_train_val_ids = cls.get_overlaps(train_smiles_ids, val_smiles_ids)
3554
cls.overlaps_train_test_ids = cls.get_overlaps(
3655
train_smiles_ids, test_smiles_ids
3756
)
3857
cls.overlaps_val_test_ids = cls.get_overlaps(val_smiles_ids, test_smiles_ids)
3958

4059
@staticmethod
41-
def get_features_ids(data_split_df):
42-
"""Returns SMILES features/tokens and SMILES IDs from the data"""
60+
def get_features_ids(data_split_df: pd.DataFrame) -> tuple[List, List]:
61+
"""
62+
Returns SMILES features/tokens and SMILES IDs from the data.
63+
64+
Args:
65+
data_split_df: DataFrame containing data to extract features and IDs from.
66+
67+
Returns:
68+
Tuple of lists: SMILES features/tokens list and SMILES IDs list.
69+
"""
4370
smiles_features = data_split_df["features"].tolist()
4471
smiles_ids = data_split_df["ident"].tolist()
4572

4673
return smiles_features, smiles_ids
4774

4875
@staticmethod
49-
def get_overlaps(list_1, list_2):
76+
def get_overlaps(list_1: List, list_2: List) -> List:
77+
"""
78+
Get overlaps between two lists.
79+
80+
Args:
81+
list_1: First list.
82+
list_2: Second list.
83+
84+
Returns:
85+
List: List of elements present in both lists.
86+
"""
5087
overlap = []
5188
for element in list_1:
5289
if element in list_2:
5390
overlap.append(element)
5491
return overlap
5592

5693
@unittest.expectedFailure
57-
def test_train_val_overlap_based_on_smiles(self):
58-
"""Check that train-val splits are performed correctly i.e.every entity
94+
def test_train_val_overlap_based_on_smiles(self) -> None:
95+
"""
96+
Check that train-val splits are performed correctly, i.e. every entity
5997
only appears in one of the train and validation set based on smiles tokens/features
6098
"""
6199
self.assertEqual(
@@ -65,18 +103,21 @@ def test_train_val_overlap_based_on_smiles(self):
65103
)
66104

67105
@unittest.expectedFailure
68-
def test_train_test_overlap_based_on_smiles(self):
69-
"""Check that train-test splits are performed correctly i.e.every entity
70-
only appears in one of the train and test set based on smiles tokens/features"""
106+
def test_train_test_overlap_based_on_smiles(self) -> None:
107+
"""
108+
Check that train-test splits are performed correctly, i.e. every entity
109+
only appears in one of the train and test set based on smiles tokens/features
110+
"""
71111
self.assertEqual(
72112
len(self.overlaps_train_test),
73113
0,
74114
"Duplicate entities present in Train and Test set based on SMILES",
75115
)
76116

77117
@unittest.expectedFailure
78-
def test_val_test_overlap_based_on_smiles(self):
79-
"""Check that val-test splits are performed correctly i.e.every entity
118+
def test_val_test_overlap_based_on_smiles(self) -> None:
119+
"""
120+
Check that val-test splits are performed correctly, i.e. every entity
80121
only appears in one of the validation and test set based on smiles tokens/features
81122
"""
82123
self.assertEqual(
@@ -85,27 +126,33 @@ def test_val_test_overlap_based_on_smiles(self):
85126
"Duplicate entities present in Validation and Test set based on SMILES",
86127
)
87128

88-
def test_train_val_overlap_based_on_ids(self):
89-
"""Check that train-val splits are performed correctly i.e.every entity
90-
only appears in one of the train and validation set based on smiles IDs"""
129+
def test_train_val_overlap_based_on_ids(self) -> None:
130+
"""
131+
Check that train-val splits are performed correctly, i.e. every entity
132+
only appears in one of the train and validation set based on smiles IDs
133+
"""
91134
self.assertEqual(
92135
len(self.overlaps_train_val_ids),
93136
0,
94137
"Duplicate entities present in Train and Validation set based on IDs",
95138
)
96139

97-
def test_train_test_overlap_based_on_ids(self):
98-
"""Check that train-test splits are performed correctly i.e.every entity
99-
only appears in one of the train and test set based on smiles IDs"""
140+
def test_train_test_overlap_based_on_ids(self) -> None:
141+
"""
142+
Check that train-test splits are performed correctly, i.e. every entity
143+
only appears in one of the train and test set based on smiles IDs
144+
"""
100145
self.assertEqual(
101146
len(self.overlaps_train_test_ids),
102147
0,
103148
"Duplicate entities present in Train and Test set based on IDs",
104149
)
105150

106-
def test_val_test_overlap_based_on_ids(self):
107-
"""Check that val-test splits are performed correctly i.e.every entity
108-
only appears in one of the validation and test set based on smiles IDs"""
151+
def test_val_test_overlap_based_on_ids(self) -> None:
152+
"""
153+
Check that val-test splits are performed correctly, i.e. every entity
154+
only appears in one of the validation and test set based on smiles IDs
155+
"""
109156
self.assertEqual(
110157
len(self.overlaps_val_test_ids),
111158
0,

tests/testChebiDynamicDataSplits.py

Lines changed: 74 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import hashlib
22
import unittest
3+
from typing import Any, List, Tuple
34

45
import numpy as np
56
import pandas as pd
@@ -8,20 +9,30 @@
89

910

1011
class TestChebiDynamicDataSplits(unittest.TestCase):
11-
"""Test dynamic splits implementation's consistency"""
12+
"""
13+
Test dynamic splits implementation's consistency for ChEBIOver50 dataset.
14+
15+
Attributes:
16+
chebi_50_v231 (ChEBIOver50): Instance of ChEBIOver50 with ChEBI version 231.
17+
chebi_50_v231_vt200 (ChEBIOver50): Instance of ChEBIOver50 with ChEBI version 231 and train version 200.
18+
"""
1219

1320
@classmethod
14-
def setUpClass(cls):
21+
def setUpClass(cls) -> None:
22+
"""
23+
Set up class method to initialize instances of ChEBIOver50 and generate data.
24+
"""
1525
cls.chebi_50_v231 = ChEBIOver50(chebi_version=231)
1626
cls.chebi_50_v231_vt200 = ChEBIOver50(
1727
chebi_version=231, chebi_version_train=200
1828
)
1929
cls._generate_chebi_class_data(cls.chebi_50_v231)
2030
cls._generate_chebi_class_data(cls.chebi_50_v231_vt200)
2131

22-
def testDynamicDataSplitsConsistency(self):
23-
"""Test Dynamic Data Splits consistency across every run"""
24-
32+
def testDynamicDataSplitsConsistency(self) -> None:
33+
"""
34+
Test Dynamic Data Splits consistency across multiple runs.
35+
"""
2536
# Dynamic Data Splits in Run 1
2637
train_hash_1, val_hash_1, test_hash_1 = self._get_hashed_splits()
2738

@@ -34,9 +45,10 @@ def testDynamicDataSplitsConsistency(self):
3445
self.assertEqual(val_hash_1, val_hash_2, "Validation data hashes do not match.")
3546
self.assertEqual(test_hash_1, test_hash_2, "Test data hashes do not match.")
3647

37-
def test_same_ids_and_in_test_sets(self):
38-
"""Check if test sets of both classes have same IDs"""
39-
48+
def test_same_ids_and_in_test_sets(self) -> None:
49+
"""
50+
Check if test sets of both classes have the same IDs.
51+
"""
4052
v231_ids = set(self.chebi_50_v231.dynamic_split_dfs["test"]["ident"])
4153
v231_vt200_ids = set(
4254
self.chebi_50_v231_vt200.dynamic_split_dfs["test"]["ident"]
@@ -46,9 +58,10 @@ def test_same_ids_and_in_test_sets(self):
4658
v231_ids, v231_vt200_ids, "Test sets do not have the same IDs."
4759
)
4860

49-
def test_labels_vector_size_in_test_sets(self):
50-
"""Check if test sets of both classes have different size/shape of labels"""
51-
61+
def test_labels_vector_size_in_test_sets(self) -> None:
62+
"""
63+
Check that the test sets of both classes have the same size/shape of labels.
64+
"""
5265
v231_labels_shape = len(
5366
self.chebi_50_v231.dynamic_split_dfs["test"]["labels"].iloc[0]
5467
)
@@ -59,11 +72,13 @@ def test_labels_vector_size_in_test_sets(self):
5972
self.assertEqual(
6073
v231_labels_shape,
6174
v231_vt200_label_shape,
62-
"Test sets have the different size of labels",
75+
"Test sets have different sizes of labels",
6376
)
6477

65-
def test_no_overlaps_in_chebi_v231_vt200(self):
66-
"""Test the overlaps for the ChEBIOver50(chebi_version=231, chebi_version_train=200)"""
78+
def test_no_overlaps_in_chebi_v231_vt200(self) -> None:
79+
"""
80+
Check that the train, validation and test splits of the ChEBIOver50(chebi_version=231, chebi_version_train=200) dataset do not share any IDs.
81+
"""
6782
train_set = self.chebi_50_v231_vt200.dynamic_split_dfs["train"]
6883
val_set = self.chebi_50_v231_vt200.dynamic_split_dfs["validation"]
6984
test_set = self.chebi_50_v231_vt200.dynamic_split_dfs["test"]
@@ -72,7 +87,7 @@ def test_no_overlaps_in_chebi_v231_vt200(self):
7287
val_set_ids = val_set["ident"].tolist()
7388
test_set_ids = test_set["ident"].tolist()
7489

75-
# ----- Get the overlap between data splits based on IDs -----
90+
# Get the overlap between data splits based on IDs
7691
self.overlaps_train_val_ids = self.get_overlaps(train_set_ids, val_set_ids)
7792
self.overlaps_train_test_ids = self.get_overlaps(train_set_ids, test_set_ids)
7893
self.overlaps_val_test_ids = self.get_overlaps(val_set_ids, test_set_ids)
@@ -93,10 +108,13 @@ def test_no_overlaps_in_chebi_v231_vt200(self):
93108
"Duplicate entities present in Validation and Test set based on IDs",
94109
)
95110

96-
def _get_hashed_splits(self):
97-
"""Returns hashed dynamic data splits"""
111+
def _get_hashed_splits(self) -> Tuple[str, str, str]:
112+
"""
113+
Returns hashed dynamic data splits.
98114
99-
# Get the raw/processed data if missing
115+
Returns:
116+
Tuple[str, str, str]: Hashes for train, validation, and test data splits.
117+
"""
100118
chebi_class_obj = self.chebi_50_v231
101119

102120
# Get dynamic splits from class variables
@@ -112,16 +130,32 @@ def _get_hashed_splits(self):
112130
return train_hash, val_hash, test_hash
113131

114132
@staticmethod
115-
def compute_hash(data):
116-
"""Returns hash for the given data partition"""
133+
def compute_hash(data: pd.DataFrame) -> str:
134+
"""
135+
Returns hash for the given data partition.
136+
137+
Args:
138+
data (pd.DataFrame): DataFrame containing data to be hashed.
139+
140+
Returns:
141+
str: Hash computed for the DataFrame.
142+
"""
117143
data_for_hashing = data.map(TestChebiDynamicDataSplits.convert_to_hashable)
118144
return hashlib.md5(
119145
pd.util.hash_pandas_object(data_for_hashing, index=True).values
120146
).hexdigest()
121147

122148
@staticmethod
123-
def convert_to_hashable(item):
124-
"""To Convert lists and numpy arrays within the DataFrame to tuples for hashing"""
149+
def convert_to_hashable(item: Any) -> Any:
150+
"""
151+
Convert lists and numpy arrays within the DataFrame to tuples for hashing.
152+
153+
Args:
154+
item (Any): Item to convert to a hashable form.
155+
156+
Returns:
157+
Any: Hashable representation of the input item.
158+
"""
125159
if isinstance(item, list):
126160
return tuple(item)
127161
elif isinstance(item, np.ndarray):
@@ -130,13 +164,28 @@ def convert_to_hashable(item):
130164
return item
131165

132166
@staticmethod
133-
def _generate_chebi_class_data(chebi_class_obj):
134-
# Get the raw/processed data if missing
167+
def _generate_chebi_class_data(chebi_class_obj: ChEBIOver50) -> None:
168+
"""
169+
Generate ChEBI class data if not already generated.
170+
171+
Args:
172+
chebi_class_obj (ChEBIOver50): Instance of ChEBIOver50 class.
173+
"""
135174
chebi_class_obj.prepare_data()
136175
chebi_class_obj.setup()
137176

138177
@staticmethod
139-
def get_overlaps(list_1, list_2):
178+
def get_overlaps(list_1: List[Any], list_2: List[Any]) -> List[Any]:
179+
"""
180+
Get overlaps between two lists.
181+
182+
Args:
183+
list_1 (List[Any]): First list.
184+
list_2 (List[Any]): Second list.
185+
186+
Returns:
187+
List[Any]: List of elements present in both lists.
188+
"""
140189
overlap = []
141190
for element in list_1:
142191
if element in list_2:

0 commit comments

Comments
 (0)