Skip to content

Commit 16fe335

Browse files
authored
Merge pull request #21 from ChEB-AI/feature/testing_framework
- added testing framework
2 parents 814ba37 + 2e8ef97 commit 16fe335

File tree

6 files changed

+395
-1
lines changed

6 files changed

+395
-1
lines changed

configs/data/chebi50.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
class_path: chebai.preprocessing.datasets.chebi.ChEBIOver50
1+
class_path: chebai.preprocessing.datasets.chebi.ChEBIOver50

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
"deepsmiles",
4848
"iterative-stratification",
4949
"wandb",
50+
"chardet",
51+
"yaml",
5052
],
5153
extras_require={"dev": ["black", "isort", "pre-commit"]},
5254
)

tests/__init__.py

Whitespace-only changes.

tests/testChebiData.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import unittest
2+
import os
3+
import torch
4+
import yaml
5+
6+
7+
class TestChebiData(unittest.TestCase):
8+
9+
@classmethod
10+
def setUpClass(cls) -> None:
11+
cls.getDataSplitsOverlaps()
12+
13+
@classmethod
14+
def getChebiDataConfig(cls):
15+
"""Import the respective class and instantiate with given version from the config"""
16+
CONFIG_FILE_NAME = "chebi50.yml"
17+
with open(
18+
os.path.join("configs", "data", f"{CONFIG_FILE_NAME}"), "r"
19+
) as yaml_file:
20+
config = yaml.safe_load(yaml_file)
21+
22+
class_path = config["class_path"]
23+
init_args = config.get("init_args", {})
24+
25+
module, class_name = class_path.rsplit(".", 1)
26+
module = __import__(module, fromlist=[class_name])
27+
class_ = getattr(module, class_name)
28+
29+
return class_(**init_args)
30+
31+
@classmethod
32+
def getDataSplitsOverlaps(cls):
33+
"""Get the overlap between data splits"""
34+
processed_path = os.path.join(
35+
os.getcwd(), cls.getChebiDataConfig().processed_dir
36+
)
37+
print(f"Checking Data from - {processed_path}")
38+
39+
train_set = torch.load(os.path.join(processed_path, "train.pt"))
40+
val_set = torch.load(os.path.join(processed_path, "validation.pt"))
41+
test_set = torch.load(os.path.join(processed_path, "test.pt"))
42+
43+
train_smiles, train_smiles_ids = cls.get_features_ids(train_set)
44+
val_smiles, val_smiles_ids = cls.get_features_ids(val_set)
45+
test_smiles, test_smiles_ids = cls.get_features_ids(test_set)
46+
47+
# ----- Get the overlap between data splits based on smiles tokens/features -----
48+
49+
cls.overlaps_train_val = cls.get_overlaps(train_smiles, val_smiles)
50+
cls.overlaps_train_test = cls.get_overlaps(train_smiles, test_smiles)
51+
cls.overlaps_val_test = cls.get_overlaps(val_smiles, test_smiles)
52+
53+
# ----- Get the overlap between data splits based on IDs -----
54+
55+
cls.overlaps_train_val_ids = cls.get_overlaps(train_smiles_ids, val_smiles_ids)
56+
cls.overlaps_train_test_ids = cls.get_overlaps(
57+
train_smiles_ids, test_smiles_ids
58+
)
59+
cls.overlaps_val_test_ids = cls.get_overlaps(val_smiles_ids, test_smiles_ids)
60+
61+
@staticmethod
62+
def get_features_ids(data_split):
63+
"""Returns SMILES features/tokens and SMILES IDs from the data"""
64+
smiles_features, smiles_ids = [], []
65+
for entry in data_split:
66+
smiles_features.append(entry["features"])
67+
smiles_ids.append(entry["ident"])
68+
69+
return smiles_features, smiles_ids
70+
71+
@staticmethod
72+
def get_overlaps(list_1, list_2):
73+
overlap = []
74+
for element in list_1:
75+
if element in list_2:
76+
overlap.append(element)
77+
return overlap
78+
79+
@unittest.expectedFailure
80+
def test_train_val_overlap_based_on_smiles(self):
81+
"""Check that train-val splits are performed correctly i.e.every entity
82+
only appears in one of the train and validation set based on smiles tokens/features
83+
"""
84+
self.assertEqual(
85+
len(self.overlaps_train_val),
86+
0,
87+
"Duplicate entities present in Train and Validation set based on SMILES",
88+
)
89+
90+
@unittest.expectedFailure
91+
def test_train_test_overlap_based_on_smiles(self):
92+
"""Check that train-test splits are performed correctly i.e.every entity
93+
only appears in one of the train and test set based on smiles tokens/features"""
94+
self.assertEqual(
95+
len(self.overlaps_train_test),
96+
0,
97+
"Duplicate entities present in Train and Test set based on SMILES",
98+
)
99+
100+
@unittest.expectedFailure
101+
def test_val_test_overlap_based_on_smiles(self):
102+
"""Check that val-test splits are performed correctly i.e.every entity
103+
only appears in one of the validation and test set based on smiles tokens/features
104+
"""
105+
self.assertEqual(
106+
len(self.overlaps_val_test),
107+
0,
108+
"Duplicate entities present in Validation and Test set based on SMILES",
109+
)
110+
111+
def test_train_val_overlap_based_on_ids(self):
112+
"""Check that train-val splits are performed correctly i.e.every entity
113+
only appears in one of the train and validation set based on smiles IDs"""
114+
self.assertEqual(
115+
len(self.overlaps_train_val_ids),
116+
0,
117+
"Duplicate entities present in Train and Validation set based on IDs",
118+
)
119+
120+
def test_train_test_overlap_based_on_ids(self):
121+
"""Check that train-test splits are performed correctly i.e.every entity
122+
only appears in one of the train and test set based on smiles IDs"""
123+
self.assertEqual(
124+
len(self.overlaps_train_test_ids),
125+
0,
126+
"Duplicate entities present in Train and Test set based on IDs",
127+
)
128+
129+
def test_val_test_overlap_based_on_ids(self):
130+
"""Check that val-test splits are performed correctly i.e.every entity
131+
only appears in one of the validation and test set based on smiles IDs"""
132+
self.assertEqual(
133+
len(self.overlaps_val_test_ids),
134+
0,
135+
"Duplicate entities present in Validation and Test set based on IDs",
136+
)
137+
138+
139+
if __name__ == "__main__":
140+
unittest.main()

tests/testPubChemData.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import unittest
2+
import os
3+
import torch
4+
from chebai.preprocessing.datasets.pubchem import PubChem
5+
6+
7+
class TestPubChemData(unittest.TestCase):
8+
9+
@classmethod
10+
def setUpClass(cls) -> None:
11+
cls.pubChem = PubChem()
12+
cls.getDataSplitsOverlaps()
13+
14+
@classmethod
15+
def getDataSplitsOverlaps(cls):
16+
"""Get the overlap between data splits"""
17+
processed_path = os.path.join(os.getcwd(), cls.pubChem.processed_dir)
18+
print(f"Checking Data from - {processed_path}")
19+
20+
train_set = torch.load(os.path.join(processed_path, "train.pt"))
21+
val_set = torch.load(os.path.join(processed_path, "validation.pt"))
22+
test_set = torch.load(os.path.join(processed_path, "test.pt"))
23+
24+
train_smiles, train_smiles_ids = cls.get_features_ids(train_set)
25+
val_smiles, val_smiles_ids = cls.get_features_ids(val_set)
26+
test_smiles, test_smiles_ids = cls.get_features_ids(test_set)
27+
28+
# ----- Get the overlap between data splits based on smiles tokens/features -----
29+
30+
# train_smiles.append(val_smiles[0])
31+
# train_smiles.append(test_smiles[0])
32+
# val_smiles.append(test_smiles[0])
33+
34+
cls.overlaps_train_val = cls.get_overlaps(train_smiles, val_smiles)
35+
cls.overlaps_train_test = cls.get_overlaps(train_smiles, test_smiles)
36+
cls.overlaps_val_test = cls.get_overlaps(val_smiles, test_smiles)
37+
38+
# ----- Get the overlap between data splits based on IDs -----
39+
40+
# train_smiles_ids.append(val_smiles_ids[0])
41+
# train_smiles_ids.append(test_smiles_ids[0])
42+
# val_smiles_ids.append(test_smiles_ids[0])
43+
44+
cls.overlaps_train_val_ids = cls.get_overlaps(train_smiles_ids, val_smiles_ids)
45+
cls.overlaps_train_test_ids = cls.get_overlaps(
46+
train_smiles_ids, test_smiles_ids
47+
)
48+
cls.overlaps_val_test_ids = cls.get_overlaps(val_smiles_ids, test_smiles_ids)
49+
50+
@staticmethod
51+
def get_features_ids(data_split):
52+
"""Returns SMILES features/tokens and SMILES IDs from the data"""
53+
smiles_features, smiles_ids = [], []
54+
for entry in data_split:
55+
smiles_features.append(entry["features"])
56+
smiles_ids.append(entry["ident"])
57+
58+
return smiles_features, smiles_ids
59+
60+
@staticmethod
61+
def get_overlaps(list_1, list_2):
62+
overlap = []
63+
for element in list_1:
64+
if element in list_2:
65+
overlap.append(element)
66+
return overlap
67+
68+
def test_train_val_overlap_based_on_smiles(self):
69+
"""Check that train-val splits are performed correctly i.e.every entity
70+
only appears in one of the train and validation set based on smiles tokens/features
71+
"""
72+
self.assertEqual(
73+
len(self.overlaps_train_val),
74+
0,
75+
"Duplicate entities present in Train and Validation set based on SMILES",
76+
)
77+
78+
def test_train_test_overlap_based_on_smiles(self):
79+
"""Check that train-test splits are performed correctly i.e.every entity
80+
only appears in one of the train and test set based on smiles tokens/features"""
81+
self.assertEqual(
82+
len(self.overlaps_train_test),
83+
0,
84+
"Duplicate entities present in Train and Test set based on SMILES",
85+
)
86+
87+
def test_val_test_overlap_based_on_smiles(self):
88+
"""Check that val-test splits are performed correctly i.e.every entity
89+
only appears in one of the validation and test set based on smiles tokens/features
90+
"""
91+
self.assertEqual(
92+
len(self.overlaps_val_test),
93+
0,
94+
"Duplicate entities present in Validation and Test set based on SMILES",
95+
)
96+
97+
def test_train_val_overlap_based_on_ids(self):
98+
"""Check that train-val splits are performed correctly i.e.every entity
99+
only appears in one of the train and validation set based on smiles IDs"""
100+
self.assertEqual(
101+
len(self.overlaps_train_val_ids),
102+
0,
103+
"Duplicate entities present in Train and Validation set based on IDs",
104+
)
105+
106+
def test_train_test_overlap_based_on_ids(self):
107+
"""Check that train-test splits are performed correctly i.e.every entity
108+
only appears in one of the train and test set based on smiles IDs"""
109+
self.assertEqual(
110+
len(self.overlaps_train_test_ids),
111+
0,
112+
"Duplicate entities present in Train and Test set based on IDs",
113+
)
114+
115+
def test_val_test_overlap_based_on_ids(self):
116+
"""Check that val-test splits are performed correctly i.e.every entity
117+
only appears in one of the validation and test set based on smiles IDs"""
118+
self.assertEqual(
119+
len(self.overlaps_val_test_ids),
120+
0,
121+
"Duplicate entities present in Validation and Test set based on IDs",
122+
)
123+
124+
125+
if __name__ == "__main__":
126+
unittest.main()

0 commit comments

Comments
 (0)