Skip to content

Commit 8c5842b

Browse files
committed
- added testing framework
1 parent 81df748 commit 8c5842b

File tree

6 files changed

+320
-0
lines changed

6 files changed

+320
-0
lines changed

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
"deepsmiles",
4848
"iterative-stratification",
4949
"wandb",
50+
"chardet",
51+
"configparser"
5052
],
5153
extras_require={"dev": ["black", "isort", "pre-commit"]},
5254
)

tests/__init__.py

Whitespace-only changes.

tests/config_chebi_data.ini

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
;Config for Chebi Data Test
2+
3+
[ChebiData]
4+
chebi_class_name=ChEBIOver50
5+
version_number=231

tests/testChebiData.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import unittest
2+
import os
3+
from pathlib import Path
4+
import torch
5+
import configparser
6+
7+
8+
class TestChebiData(unittest.TestCase):
9+
10+
@classmethod
11+
def setUpClass(cls) -> None:
12+
cls.getChebiDataConfig()
13+
cls.getDataSplitsOverlaps()
14+
15+
@classmethod
16+
def getChebiDataConfig(cls):
17+
"""Import the respective class and instantiate with given version from the config"""
18+
config = configparser.ConfigParser()
19+
config_file_path = Path(os.path.join(os.getcwd(), "tests/config_chebi_data.ini"))
20+
config.read(config_file_path)
21+
22+
class_name = config.get('ChebiData', 'chebi_class_name')
23+
version_number = config.get('ChebiData', 'version_number')
24+
25+
module = __import__('chebai.preprocessing.datasets.chebi', fromlist=[class_name])
26+
class_ = getattr(module, class_name)
27+
28+
cls.chebi_class = class_(chebi_version=version_number)
29+
30+
@classmethod
31+
def getDataSplitsOverlaps(cls):
32+
"""Get the overlap between data splits"""
33+
processed_path = os.path.join(os.getcwd(), cls.chebi_class.processed_dir)
34+
print(f"Checking Data from - {processed_path}")
35+
36+
train_set = torch.load(os.path.join(processed_path, "train.pt"))
37+
val_set = torch.load(os.path.join(processed_path, "validation.pt"))
38+
test_set = torch.load(os.path.join(processed_path, "test.pt"))
39+
40+
train_smiles, train_smiles_ids = cls.get_features_ids(train_set)
41+
val_smiles, val_smiles_ids = cls.get_features_ids(val_set)
42+
test_smiles, test_smiles_ids = cls.get_features_ids(test_set)
43+
44+
# ----- Get the overlap between data splits based on smiles tokens/features -----
45+
46+
# train_smiles.append(val_smiles[0])
47+
# train_smiles.append(test_smiles[0])
48+
# val_smiles.append(test_smiles[0])
49+
50+
cls.overlaps_train_val = cls.get_overlaps(train_smiles, val_smiles)
51+
cls.overlaps_train_test = cls.get_overlaps(train_smiles, test_smiles)
52+
cls.overlaps_val_test = cls.get_overlaps(val_smiles, test_smiles)
53+
54+
# ----- Get the overlap between data splits based on IDs -----
55+
56+
# train_smiles_ids.append(val_smiles_ids[0])
57+
# train_smiles_ids.append(test_smiles_ids[0])
58+
# val_smiles_ids.append(test_smiles_ids[0])
59+
60+
cls.overlaps_train_val_ids = cls.get_overlaps(train_smiles_ids, val_smiles_ids)
61+
cls.overlaps_train_test_ids = cls.get_overlaps(train_smiles_ids, test_smiles_ids)
62+
cls.overlaps_val_test_ids = cls.get_overlaps(val_smiles_ids, test_smiles_ids)
63+
64+
@staticmethod
65+
def get_features_ids(data_split):
66+
"""Returns SMILES features/tokens and SMILES IDs from the data"""
67+
smiles_features, smiles_ids = [], []
68+
for entry in data_split:
69+
smiles_features.append(entry["features"])
70+
smiles_ids.append(entry["ident"])
71+
72+
return smiles_features, smiles_ids
73+
74+
@staticmethod
75+
def get_overlaps(list_1, list_2):
76+
overlap = []
77+
for element in list_1:
78+
if element in list_2:
79+
overlap.append(element)
80+
return overlap
81+
82+
@unittest.expectedFailure
83+
def test_train_val_overlap_based_on_smiles(self):
84+
"""Check that train-val splits are performed correctly i.e.every entity
85+
only appears in one of the train and validation set based on smiles tokens/features"""
86+
self.assertEqual(len(self.overlaps_train_val), 0, "Duplicate entities present in Train and Validation set based on SMILES")
87+
88+
@unittest.expectedFailure
89+
def test_train_test_overlap_based_on_smiles(self):
90+
"""Check that train-test splits are performed correctly i.e.every entity
91+
only appears in one of the train and test set based on smiles tokens/features"""
92+
self.assertEqual(len(self.overlaps_train_test), 0, "Duplicate entities present in Train and Test set based on SMILES")
93+
94+
@unittest.expectedFailure
95+
def test_val_test_overlap_based_on_smiles(self):
96+
"""Check that val-test splits are performed correctly i.e.every entity
97+
only appears in one of the validation and test set based on smiles tokens/features"""
98+
self.assertEqual(len(self.overlaps_val_test), 0, "Duplicate entities present in Validation and Test set based on SMILES")
99+
100+
def test_train_val_overlap_based_on_ids(self):
101+
"""Check that train-val splits are performed correctly i.e.every entity
102+
only appears in one of the train and validation set based on smiles IDs"""
103+
self.assertEqual(len(self.overlaps_train_val_ids), 0, "Duplicate entities present in Train and Validation set based on IDs")
104+
105+
def test_train_test_overlap_based_on_ids(self):
106+
"""Check that train-test splits are performed correctly i.e.every entity
107+
only appears in one of the train and test set based on smiles IDs"""
108+
self.assertEqual(len(self.overlaps_train_test_ids), 0, "Duplicate entities present in Train and Test set based on IDs")
109+
110+
def test_val_test_overlap_based_on_ids(self):
111+
"""Check that val-test splits are performed correctly i.e.every entity
112+
only appears in one of the validation and test set based on smiles IDs"""
113+
self.assertEqual(len(self.overlaps_val_test_ids), 0, "Duplicate entities present in Validation and Test set based on IDs")
114+
115+
116+
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()

tests/testPubChemData.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import unittest
2+
import os
3+
import torch
4+
from chebai.preprocessing.datasets.pubchem import PubChem
5+
6+
7+
class TestPubChemData(unittest.TestCase):
8+
9+
@classmethod
10+
def setUpClass(cls) -> None:
11+
cls.pubChem = PubChem()
12+
cls.getDataSplitsOverlaps()
13+
14+
@classmethod
15+
def getDataSplitsOverlaps(cls):
16+
"""Get the overlap between data splits"""
17+
processed_path = os.path.join(os.getcwd(), cls.pubChem.processed_dir)
18+
print(f"Checking Data from - {processed_path}")
19+
20+
train_set = torch.load(os.path.join(processed_path, "train.pt"))
21+
val_set = torch.load(os.path.join(processed_path, "validation.pt"))
22+
test_set = torch.load(os.path.join(processed_path, "test.pt"))
23+
24+
train_smiles, train_smiles_ids = cls.get_features_ids(train_set)
25+
val_smiles, val_smiles_ids = cls.get_features_ids(val_set)
26+
test_smiles, test_smiles_ids = cls.get_features_ids(test_set)
27+
28+
# ----- Get the overlap between data splits based on smiles tokens/features -----
29+
30+
# train_smiles.append(val_smiles[0])
31+
# train_smiles.append(test_smiles[0])
32+
# val_smiles.append(test_smiles[0])
33+
34+
cls.overlaps_train_val = cls.get_overlaps(train_smiles, val_smiles)
35+
cls.overlaps_train_test = cls.get_overlaps(train_smiles, test_smiles)
36+
cls.overlaps_val_test = cls.get_overlaps(val_smiles, test_smiles)
37+
38+
# ----- Get the overlap between data splits based on IDs -----
39+
40+
# train_smiles_ids.append(val_smiles_ids[0])
41+
# train_smiles_ids.append(test_smiles_ids[0])
42+
# val_smiles_ids.append(test_smiles_ids[0])
43+
44+
cls.overlaps_train_val_ids = cls.get_overlaps(train_smiles_ids, val_smiles_ids)
45+
cls.overlaps_train_test_ids = cls.get_overlaps(train_smiles_ids, test_smiles_ids)
46+
cls.overlaps_val_test_ids = cls.get_overlaps(val_smiles_ids, test_smiles_ids)
47+
48+
@staticmethod
49+
def get_features_ids(data_split):
50+
"""Returns SMILES features/tokens and SMILES IDs from the data"""
51+
smiles_features, smiles_ids = [], []
52+
for entry in data_split:
53+
smiles_features.append(entry["features"])
54+
smiles_ids.append(entry["ident"])
55+
56+
return smiles_features, smiles_ids
57+
58+
@staticmethod
59+
def get_overlaps(list_1, list_2):
60+
overlap = []
61+
for element in list_1:
62+
if element in list_2:
63+
overlap.append(element)
64+
return overlap
65+
66+
def test_train_val_overlap_based_on_smiles(self):
67+
"""Check that train-val splits are performed correctly i.e.every entity
68+
only appears in one of the train and validation set based on smiles tokens/features"""
69+
self.assertEqual(len(self.overlaps_train_val), 0, "Duplicate entities present in Train and Validation set based on SMILES")
70+
71+
def test_train_test_overlap_based_on_smiles(self):
72+
"""Check that train-test splits are performed correctly i.e.every entity
73+
only appears in one of the train and test set based on smiles tokens/features"""
74+
self.assertEqual(len(self.overlaps_train_test), 0, "Duplicate entities present in Train and Test set based on SMILES")
75+
76+
def test_val_test_overlap_based_on_smiles(self):
77+
"""Check that val-test splits are performed correctly i.e.every entity
78+
only appears in one of the validation and test set based on smiles tokens/features"""
79+
self.assertEqual(len(self.overlaps_val_test), 0, "Duplicate entities present in Validation and Test set based on SMILES")
80+
81+
def test_train_val_overlap_based_on_ids(self):
82+
"""Check that train-val splits are performed correctly i.e.every entity
83+
only appears in one of the train and validation set based on smiles IDs"""
84+
self.assertEqual(len(self.overlaps_train_val_ids), 0, "Duplicate entities present in Train and Validation set based on IDs")
85+
86+
def test_train_test_overlap_based_on_ids(self):
87+
"""Check that train-test splits are performed correctly i.e.every entity
88+
only appears in one of the train and test set based on smiles IDs"""
89+
self.assertEqual(len(self.overlaps_train_test_ids), 0, "Duplicate entities present in Train and Test set based on IDs")
90+
91+
def test_val_test_overlap_based_on_ids(self):
92+
"""Check that val-test splits are performed correctly i.e.every entity
93+
only appears in one of the validation and test set based on smiles IDs"""
94+
self.assertEqual(len(self.overlaps_val_test_ids), 0, "Duplicate entities present in Validation and Test set based on IDs")
95+
96+
97+
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()

tests/testTox21MolNetData.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import unittest
2+
import os
3+
import torch
4+
from chebai.preprocessing.datasets.tox21 import Tox21MolNetChem
5+
6+
7+
class TestPubChemData(unittest.TestCase):
8+
9+
@classmethod
10+
def setUpClass(cls) -> None:
11+
cls.tox21 = Tox21MolNetChem()
12+
cls.getDataSplitsOverlaps()
13+
14+
@classmethod
15+
def getDataSplitsOverlaps(cls):
16+
"""Get the overlap between data splits"""
17+
processed_path = os.path.join(os.getcwd(), cls.tox21.processed_dir)
18+
print(f"Checking Data from - {processed_path}")
19+
20+
train_set = torch.load(os.path.join(processed_path, "train.pt"))
21+
val_set = torch.load(os.path.join(processed_path, "validation.pt"))
22+
test_set = torch.load(os.path.join(processed_path, "test.pt"))
23+
24+
train_smiles, train_smiles_ids = cls.get_features_ids(train_set)
25+
val_smiles, val_smiles_ids = cls.get_features_ids(val_set)
26+
test_smiles, test_smiles_ids = cls.get_features_ids(test_set)
27+
28+
# ----- Get the overlap between data splits based on smiles tokens/features -----
29+
30+
# train_smiles.append(val_smiles[0])
31+
# train_smiles.append(test_smiles[0])
32+
# val_smiles.append(test_smiles[0])
33+
34+
cls.overlaps_train_val = cls.get_overlaps(train_smiles, val_smiles)
35+
cls.overlaps_train_test = cls.get_overlaps(train_smiles, test_smiles)
36+
cls.overlaps_val_test = cls.get_overlaps(val_smiles, test_smiles)
37+
38+
# ----- Get the overlap between data splits based on IDs -----
39+
40+
# train_smiles_ids.append(val_smiles_ids[0])
41+
# train_smiles_ids.append(test_smiles_ids[0])
42+
# val_smiles_ids.append(test_smiles_ids[0])
43+
44+
cls.overlaps_train_val_ids = cls.get_overlaps(train_smiles_ids, val_smiles_ids)
45+
cls.overlaps_train_test_ids = cls.get_overlaps(train_smiles_ids, test_smiles_ids)
46+
cls.overlaps_val_test_ids = cls.get_overlaps(val_smiles_ids, test_smiles_ids)
47+
48+
@staticmethod
49+
def get_features_ids(data_split):
50+
"""Returns SMILES features/tokens and SMILES IDs from the data"""
51+
smiles_features, smiles_ids = [], []
52+
for entry in data_split:
53+
smiles_features.append(entry["features"])
54+
smiles_ids.append(entry["ident"])
55+
56+
return smiles_features, smiles_ids
57+
58+
@staticmethod
59+
def get_overlaps(list_1, list_2):
60+
overlap = []
61+
for element in list_1:
62+
if element in list_2:
63+
overlap.append(element)
64+
return overlap
65+
66+
def test_train_val_overlap_based_on_smiles(self):
67+
"""Check that train-val splits are performed correctly i.e.every entity
68+
only appears in one of the train and validation set based on smiles tokens/features"""
69+
self.assertEqual(len(self.overlaps_train_val), 0, "Duplicate entities present in Train and Validation set based on SMILES")
70+
71+
def test_train_test_overlap_based_on_smiles(self):
72+
"""Check that train-test splits are performed correctly i.e.every entity
73+
only appears in one of the train and test set based on smiles tokens/features"""
74+
self.assertEqual(len(self.overlaps_train_test), 0, "Duplicate entities present in Train and Test set based on SMILES")
75+
76+
def test_val_test_overlap_based_on_smiles(self):
77+
"""Check that val-test splits are performed correctly i.e.every entity
78+
only appears in one of the validation and test set based on smiles tokens/features"""
79+
self.assertEqual(len(self.overlaps_val_test), 0, "Duplicate entities present in Validation and Test set based on SMILES")
80+
81+
def test_train_val_overlap_based_on_ids(self):
82+
"""Check that train-val splits are performed correctly i.e.every entity
83+
only appears in one of the train and validation set based on smiles IDs"""
84+
self.assertEqual(len(self.overlaps_train_val_ids), 0, "Duplicate entities present in Train and Validation set based on IDs")
85+
86+
def test_train_test_overlap_based_on_ids(self):
87+
"""Check that train-test splits are performed correctly i.e.every entity
88+
only appears in one of the train and test set based on smiles IDs"""
89+
self.assertEqual(len(self.overlaps_train_test_ids), 0, "Duplicate entities present in Train and Test set based on IDs")
90+
91+
def test_val_test_overlap_based_on_ids(self):
92+
"""Check that val-test splits are performed correctly i.e.every entity
93+
only appears in one of the validation and test set based on smiles IDs"""
94+
self.assertEqual(len(self.overlaps_val_test_ids), 0, "Duplicate entities present in Validation and Test set based on IDs")
95+
96+
97+
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()

0 commit comments

Comments
 (0)