Skip to content

Commit 270ca46

Browse files
committed
added minor suggested changes
1 parent 8c5842b commit 270ca46

File tree

5 files changed

+128
-42
lines changed

5 files changed

+128
-42
lines changed

configs/data/chebi50.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
class_path: chebai.preprocessing.datasets.chebi.ChEBIOver50
1+
class_path: chebai.preprocessing.datasets.chebi.ChEBIOver50
2+
init_args:
3+
chebi_version: 231

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
"iterative-stratification",
4949
"wandb",
5050
"chardet",
51-
"configparser"
51+
"yaml",
5252
],
5353
extras_require={"dev": ["black", "isort", "pre-commit"]},
5454
)

tests/testChebiData.py

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import unittest
22
import os
3-
from pathlib import Path
43
import torch
5-
import configparser
4+
import yaml
65

76

87
class TestChebiData(unittest.TestCase):
@@ -15,17 +14,18 @@ def setUpClass(cls) -> None:
1514
@classmethod
1615
def getChebiDataConfig(cls):
1716
"""Import the respective class and instantiate with given version from the config"""
18-
config = configparser.ConfigParser()
19-
config_file_path = Path(os.path.join(os.getcwd(), "tests/config_chebi_data.ini"))
20-
config.read(config_file_path)
17+
CONFIG_FILE_NAME = "chebi50.yml"
18+
with open(f"configS/data/{CONFIG_FILE_NAME}", "r") as yaml_file:
19+
config = yaml.safe_load(yaml_file)
2120

22-
class_name = config.get('ChebiData', 'chebi_class_name')
23-
version_number = config.get('ChebiData', 'version_number')
21+
class_path = config["class_path"]
22+
init_args = config.get("init_args", {})
2423

25-
module = __import__('chebai.preprocessing.datasets.chebi', fromlist=[class_name])
24+
module, class_name = class_path.rsplit(".", 1)
25+
module = __import__(module, fromlist=[class_name])
2626
class_ = getattr(module, class_name)
2727

28-
cls.chebi_class = class_(chebi_version=version_number)
28+
cls.chebi_class = class_(**init_args)
2929

3030
@classmethod
3131
def getDataSplitsOverlaps(cls):
@@ -58,7 +58,9 @@ def getDataSplitsOverlaps(cls):
5858
# val_smiles_ids.append(test_smiles_ids[0])
5959

6060
cls.overlaps_train_val_ids = cls.get_overlaps(train_smiles_ids, val_smiles_ids)
61-
cls.overlaps_train_test_ids = cls.get_overlaps(train_smiles_ids, test_smiles_ids)
61+
cls.overlaps_train_test_ids = cls.get_overlaps(
62+
train_smiles_ids, test_smiles_ids
63+
)
6264
cls.overlaps_val_test_ids = cls.get_overlaps(val_smiles_ids, test_smiles_ids)
6365

6466
@staticmethod
@@ -82,36 +84,62 @@ def get_overlaps(list_1, list_2):
8284
@unittest.expectedFailure
8385
def test_train_val_overlap_based_on_smiles(self):
8486
"""Check that train-val splits are performed correctly i.e.every entity
85-
only appears in one of the train and validation set based on smiles tokens/features"""
86-
self.assertEqual(len(self.overlaps_train_val), 0, "Duplicate entities present in Train and Validation set based on SMILES")
87+
only appears in one of the train and validation set based on smiles tokens/features
88+
"""
89+
self.assertEqual(
90+
len(self.overlaps_train_val),
91+
0,
92+
"Duplicate entities present in Train and Validation set based on SMILES",
93+
)
8794

8895
@unittest.expectedFailure
8996
def test_train_test_overlap_based_on_smiles(self):
9097
"""Check that train-test splits are performed correctly i.e.every entity
9198
only appears in one of the train and test set based on smiles tokens/features"""
92-
self.assertEqual(len(self.overlaps_train_test), 0, "Duplicate entities present in Train and Test set based on SMILES")
99+
self.assertEqual(
100+
len(self.overlaps_train_test),
101+
0,
102+
"Duplicate entities present in Train and Test set based on SMILES",
103+
)
93104

94105
@unittest.expectedFailure
95106
def test_val_test_overlap_based_on_smiles(self):
96107
"""Check that val-test splits are performed correctly i.e.every entity
97-
only appears in one of the validation and test set based on smiles tokens/features"""
98-
self.assertEqual(len(self.overlaps_val_test), 0, "Duplicate entities present in Validation and Test set based on SMILES")
108+
only appears in one of the validation and test set based on smiles tokens/features
109+
"""
110+
self.assertEqual(
111+
len(self.overlaps_val_test),
112+
0,
113+
"Duplicate entities present in Validation and Test set based on SMILES",
114+
)
99115

100116
def test_train_val_overlap_based_on_ids(self):
101117
"""Check that train-val splits are performed correctly i.e.every entity
102118
only appears in one of the train and validation set based on smiles IDs"""
103-
self.assertEqual(len(self.overlaps_train_val_ids), 0, "Duplicate entities present in Train and Validation set based on IDs")
119+
self.assertEqual(
120+
len(self.overlaps_train_val_ids),
121+
0,
122+
"Duplicate entities present in Train and Validation set based on IDs",
123+
)
104124

105125
def test_train_test_overlap_based_on_ids(self):
106126
"""Check that train-test splits are performed correctly i.e.every entity
107127
only appears in one of the train and test set based on smiles IDs"""
108-
self.assertEqual(len(self.overlaps_train_test_ids), 0, "Duplicate entities present in Train and Test set based on IDs")
128+
self.assertEqual(
129+
len(self.overlaps_train_test_ids),
130+
0,
131+
"Duplicate entities present in Train and Test set based on IDs",
132+
)
109133

110134
def test_val_test_overlap_based_on_ids(self):
111135
"""Check that val-test splits are performed correctly i.e.every entity
112136
only appears in one of the validation and test set based on smiles IDs"""
113-
self.assertEqual(len(self.overlaps_val_test_ids), 0, "Duplicate entities present in Validation and Test set based on IDs")
137+
self.assertEqual(
138+
len(self.overlaps_val_test_ids),
139+
0,
140+
"Duplicate entities present in Validation and Test set based on IDs",
141+
)
114142

115143

116-
if __name__ == '__main__':
117-
unittest.main()
144+
if __name__ == "__main__":
145+
unittest.main()

tests/testPubChemData.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ def getDataSplitsOverlaps(cls):
4242
# val_smiles_ids.append(test_smiles_ids[0])
4343

4444
cls.overlaps_train_val_ids = cls.get_overlaps(train_smiles_ids, val_smiles_ids)
45-
cls.overlaps_train_test_ids = cls.get_overlaps(train_smiles_ids, test_smiles_ids)
45+
cls.overlaps_train_test_ids = cls.get_overlaps(
46+
train_smiles_ids, test_smiles_ids
47+
)
4648
cls.overlaps_val_test_ids = cls.get_overlaps(val_smiles_ids, test_smiles_ids)
4749

4850
@staticmethod
@@ -65,34 +67,60 @@ def get_overlaps(list_1, list_2):
6567

6668
def test_train_val_overlap_based_on_smiles(self):
6769
"""Check that train-val splits are performed correctly i.e.every entity
68-
only appears in one of the train and validation set based on smiles tokens/features"""
69-
self.assertEqual(len(self.overlaps_train_val), 0, "Duplicate entities present in Train and Validation set based on SMILES")
70+
only appears in one of the train and validation set based on smiles tokens/features
71+
"""
72+
self.assertEqual(
73+
len(self.overlaps_train_val),
74+
0,
75+
"Duplicate entities present in Train and Validation set based on SMILES",
76+
)
7077

7178
def test_train_test_overlap_based_on_smiles(self):
7279
"""Check that train-test splits are performed correctly i.e.every entity
7380
only appears in one of the train and test set based on smiles tokens/features"""
74-
self.assertEqual(len(self.overlaps_train_test), 0, "Duplicate entities present in Train and Test set based on SMILES")
81+
self.assertEqual(
82+
len(self.overlaps_train_test),
83+
0,
84+
"Duplicate entities present in Train and Test set based on SMILES",
85+
)
7586

7687
def test_val_test_overlap_based_on_smiles(self):
7788
"""Check that val-test splits are performed correctly i.e.every entity
78-
only appears in one of the validation and test set based on smiles tokens/features"""
79-
self.assertEqual(len(self.overlaps_val_test), 0, "Duplicate entities present in Validation and Test set based on SMILES")
89+
only appears in one of the validation and test set based on smiles tokens/features
90+
"""
91+
self.assertEqual(
92+
len(self.overlaps_val_test),
93+
0,
94+
"Duplicate entities present in Validation and Test set based on SMILES",
95+
)
8096

8197
def test_train_val_overlap_based_on_ids(self):
8298
"""Check that train-val splits are performed correctly i.e.every entity
8399
only appears in one of the train and validation set based on smiles IDs"""
84-
self.assertEqual(len(self.overlaps_train_val_ids), 0, "Duplicate entities present in Train and Validation set based on IDs")
100+
self.assertEqual(
101+
len(self.overlaps_train_val_ids),
102+
0,
103+
"Duplicate entities present in Train and Validation set based on IDs",
104+
)
85105

86106
def test_train_test_overlap_based_on_ids(self):
87107
"""Check that train-test splits are performed correctly i.e.every entity
88108
only appears in one of the train and test set based on smiles IDs"""
89-
self.assertEqual(len(self.overlaps_train_test_ids), 0, "Duplicate entities present in Train and Test set based on IDs")
109+
self.assertEqual(
110+
len(self.overlaps_train_test_ids),
111+
0,
112+
"Duplicate entities present in Train and Test set based on IDs",
113+
)
90114

91115
def test_val_test_overlap_based_on_ids(self):
92116
"""Check that val-test splits are performed correctly i.e.every entity
93117
only appears in one of the validation and test set based on smiles IDs"""
94-
self.assertEqual(len(self.overlaps_val_test_ids), 0, "Duplicate entities present in Validation and Test set based on IDs")
118+
self.assertEqual(
119+
len(self.overlaps_val_test_ids),
120+
0,
121+
"Duplicate entities present in Validation and Test set based on IDs",
122+
)
95123

96124

97-
if __name__ == '__main__':
125+
if __name__ == "__main__":
98126
unittest.main()

tests/testTox21MolNetData.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ def getDataSplitsOverlaps(cls):
4242
# val_smiles_ids.append(test_smiles_ids[0])
4343

4444
cls.overlaps_train_val_ids = cls.get_overlaps(train_smiles_ids, val_smiles_ids)
45-
cls.overlaps_train_test_ids = cls.get_overlaps(train_smiles_ids, test_smiles_ids)
45+
cls.overlaps_train_test_ids = cls.get_overlaps(
46+
train_smiles_ids, test_smiles_ids
47+
)
4648
cls.overlaps_val_test_ids = cls.get_overlaps(val_smiles_ids, test_smiles_ids)
4749

4850
@staticmethod
@@ -65,34 +67,60 @@ def get_overlaps(list_1, list_2):
6567

6668
def test_train_val_overlap_based_on_smiles(self):
6769
"""Check that train-val splits are performed correctly i.e.every entity
68-
only appears in one of the train and validation set based on smiles tokens/features"""
69-
self.assertEqual(len(self.overlaps_train_val), 0, "Duplicate entities present in Train and Validation set based on SMILES")
70+
only appears in one of the train and validation set based on smiles tokens/features
71+
"""
72+
self.assertEqual(
73+
len(self.overlaps_train_val),
74+
0,
75+
"Duplicate entities present in Train and Validation set based on SMILES",
76+
)
7077

7178
def test_train_test_overlap_based_on_smiles(self):
7279
"""Check that train-test splits are performed correctly i.e.every entity
7380
only appears in one of the train and test set based on smiles tokens/features"""
74-
self.assertEqual(len(self.overlaps_train_test), 0, "Duplicate entities present in Train and Test set based on SMILES")
81+
self.assertEqual(
82+
len(self.overlaps_train_test),
83+
0,
84+
"Duplicate entities present in Train and Test set based on SMILES",
85+
)
7586

7687
def test_val_test_overlap_based_on_smiles(self):
7788
"""Check that val-test splits are performed correctly i.e.every entity
78-
only appears in one of the validation and test set based on smiles tokens/features"""
79-
self.assertEqual(len(self.overlaps_val_test), 0, "Duplicate entities present in Validation and Test set based on SMILES")
89+
only appears in one of the validation and test set based on smiles tokens/features
90+
"""
91+
self.assertEqual(
92+
len(self.overlaps_val_test),
93+
0,
94+
"Duplicate entities present in Validation and Test set based on SMILES",
95+
)
8096

8197
def test_train_val_overlap_based_on_ids(self):
8298
"""Check that train-val splits are performed correctly i.e.every entity
8399
only appears in one of the train and validation set based on smiles IDs"""
84-
self.assertEqual(len(self.overlaps_train_val_ids), 0, "Duplicate entities present in Train and Validation set based on IDs")
100+
self.assertEqual(
101+
len(self.overlaps_train_val_ids),
102+
0,
103+
"Duplicate entities present in Train and Validation set based on IDs",
104+
)
85105

86106
def test_train_test_overlap_based_on_ids(self):
87107
"""Check that train-test splits are performed correctly i.e.every entity
88108
only appears in one of the train and test set based on smiles IDs"""
89-
self.assertEqual(len(self.overlaps_train_test_ids), 0, "Duplicate entities present in Train and Test set based on IDs")
109+
self.assertEqual(
110+
len(self.overlaps_train_test_ids),
111+
0,
112+
"Duplicate entities present in Train and Test set based on IDs",
113+
)
90114

91115
def test_val_test_overlap_based_on_ids(self):
92116
"""Check that val-test splits are performed correctly i.e.every entity
93117
only appears in one of the validation and test set based on smiles IDs"""
94-
self.assertEqual(len(self.overlaps_val_test_ids), 0, "Duplicate entities present in Validation and Test set based on IDs")
118+
self.assertEqual(
119+
len(self.overlaps_val_test_ids),
120+
0,
121+
"Duplicate entities present in Validation and Test set based on IDs",
122+
)
95123

96124

97-
if __name__ == '__main__':
125+
if __name__ == "__main__":
98126
unittest.main()

0 commit comments

Comments
 (0)