1+ import unittest
2+ import os
3+ from pathlib import Path
4+ import torch
5+ import configparser
6+
7+
class TestChebiData(unittest.TestCase):
    """Check that the processed ChEBI train/validation/test splits do not overlap.

    The dataset class name and ChEBI version are read from
    ``tests/config_chebi_data.ini``; that path, as well as the dataset's
    ``processed_dir``, is resolved against the current working directory.
    All overlaps are computed once in ``setUpClass`` and stored as class
    attributes that the individual test methods only inspect.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Instantiate the configured dataset class and pre-compute the overlaps."""
        cls.getChebiDataConfig()
        cls.getDataSplitsOverlaps()

    @classmethod
    def getChebiDataConfig(cls):
        """Import the configured dataset class and instantiate it.

        Reads ``chebi_class_name`` and ``version_number`` from the
        ``[ChebiData]`` section of ``tests/config_chebi_data.ini`` and stores
        the resulting instance in ``cls.chebi_class``.

        Raises:
            unittest.SkipTest: if the config file could not be read.
            configparser.Error: if the section or one of the options is missing.
        """
        # Function-scope import so this class does not depend on an extra
        # top-of-file import.
        import importlib

        config_file_path = Path(os.getcwd()) / "tests" / "config_chebi_data.ini"
        config = configparser.ConfigParser()
        # ConfigParser.read() silently ignores files it cannot open and returns
        # the list of files it did parse; without this check a wrong working
        # directory only shows up later as a confusing NoSectionError.
        if not config.read(config_file_path):
            raise unittest.SkipTest(
                f"ChEBI test config could not be read: {config_file_path}"
            )

        class_name = config.get("ChebiData", "chebi_class_name")
        # Passed on exactly as read from the config, i.e. as a string.
        version_number = config.get("ChebiData", "version_number")

        module = importlib.import_module("chebai.preprocessing.datasets.chebi")
        dataset_class = getattr(module, class_name)

        cls.chebi_class = dataset_class(chebi_version=version_number)

    @classmethod
    def getDataSplitsOverlaps(cls):
        """Load the three processed splits and compute their pairwise overlaps.

        Sets ``overlaps_train_val``, ``overlaps_train_test`` and
        ``overlaps_val_test`` (compared on the ``"features"`` entries) and the
        matching ``*_ids`` attributes (compared on the ``"ident"`` entries).

        Raises:
            unittest.SkipTest: if any of the split files does not exist.
        """
        processed_path = os.path.join(os.getcwd(), cls.chebi_class.processed_dir)
        print(f"Checking Data from - {processed_path} ")

        split_paths = {
            "train": os.path.join(processed_path, "train.pt"),
            "val": os.path.join(processed_path, "validation.pt"),
            "test": os.path.join(processed_path, "test.pt"),
        }
        missing = [path for path in split_paths.values() if not os.path.isfile(path)]
        if missing:
            # Without the processed files there is nothing to verify.
            raise unittest.SkipTest(
                "Processed split file(s) not found: " + ", ".join(missing)
            )

        features, idents = {}, {}
        for split_name, path in split_paths.items():
            features[split_name], idents[split_name] = cls.get_features_ids(
                torch.load(path)
            )

        # ----- Overlap between data splits based on SMILES tokens/features -----
        cls.overlaps_train_val = cls.get_overlaps(features["train"], features["val"])
        cls.overlaps_train_test = cls.get_overlaps(features["train"], features["test"])
        cls.overlaps_val_test = cls.get_overlaps(features["val"], features["test"])

        # ----- Overlap between data splits based on IDs -----
        cls.overlaps_train_val_ids = cls.get_overlaps(idents["train"], idents["val"])
        cls.overlaps_train_test_ids = cls.get_overlaps(idents["train"], idents["test"])
        cls.overlaps_val_test_ids = cls.get_overlaps(idents["val"], idents["test"])

    @staticmethod
    def get_features_ids(data_split):
        """Return the ``"features"`` and ``"ident"`` values of every entry.

        Args:
            data_split: iterable of mappings that each provide the keys
                ``"features"`` and ``"ident"``.

        Returns:
            A ``(features, idents)`` tuple of two lists in input order.
        """
        smiles_features, smiles_ids = [], []
        for entry in data_split:
            smiles_features.append(entry["features"])
            smiles_ids.append(entry["ident"])

        return smiles_features, smiles_ids

    @staticmethod
    def _overlap_key(element):
        """Return a hashable stand-in for ``element`` where that is cheap to build.

        Objects exposing ``tolist()`` (e.g. NumPy arrays, torch tensors) are
        converted first; a flat list then becomes a tuple. Anything else is
        returned unchanged.
        """
        if hasattr(element, "tolist"):
            element = element.tolist()
        return tuple(element) if isinstance(element, list) else element

    @staticmethod
    def get_overlaps(list_1, list_2):
        """Return the elements of ``list_1`` that also occur in ``list_2``.

        Order and duplicates of ``list_1`` are preserved. Membership is tested
        through a set of hashable keys, so the cost is linear in the combined
        input size instead of ``len(list_1) * len(list_2)``. As a consequence a
        list and a tuple with the same items are treated as equal. If an
        element cannot be hashed (e.g. nested lists) the plain ``in`` scan is
        used instead.
        """
        key = TestChebiData._overlap_key
        try:
            lookup = {key(element) for element in list_2}
            return [element for element in list_1 if key(element) in lookup]
        except TypeError:
            # Unhashable elements: fall back to the linear membership test.
            return [element for element in list_1 if element in list_2]

    # NOTE(review): the three SMILES-based checks are marked as expected
    # failures; SOURCE does not say why (presumably the same token sequence can
    # occur under different IDs) - confirm. With this decorator a passing test
    # is reported as an "unexpected success".
    @unittest.expectedFailure
    def test_train_val_overlap_based_on_smiles(self):
        """Check that train-val splits are performed correctly, i.e. every entity
        only appears in one of the train and validation set based on smiles tokens/features"""
        self.assertEqual(
            len(self.overlaps_train_val),
            0,
            "Duplicate entities present in Train and Validation set based on SMILES",
        )

    @unittest.expectedFailure
    def test_train_test_overlap_based_on_smiles(self):
        """Check that train-test splits are performed correctly, i.e. every entity
        only appears in one of the train and test set based on smiles tokens/features"""
        self.assertEqual(
            len(self.overlaps_train_test),
            0,
            "Duplicate entities present in Train and Test set based on SMILES",
        )

    @unittest.expectedFailure
    def test_val_test_overlap_based_on_smiles(self):
        """Check that val-test splits are performed correctly, i.e. every entity
        only appears in one of the validation and test set based on smiles tokens/features"""
        self.assertEqual(
            len(self.overlaps_val_test),
            0,
            "Duplicate entities present in Validation and Test set based on SMILES",
        )

    def test_train_val_overlap_based_on_ids(self):
        """Check that train-val splits are performed correctly, i.e. every entity
        only appears in one of the train and validation set based on smiles IDs"""
        self.assertEqual(
            len(self.overlaps_train_val_ids),
            0,
            "Duplicate entities present in Train and Validation set based on IDs",
        )

    def test_train_test_overlap_based_on_ids(self):
        """Check that train-test splits are performed correctly, i.e. every entity
        only appears in one of the train and test set based on smiles IDs"""
        self.assertEqual(
            len(self.overlaps_train_test_ids),
            0,
            "Duplicate entities present in Train and Test set based on IDs",
        )

    def test_val_test_overlap_based_on_ids(self):
        """Check that val-test splits are performed correctly, i.e. every entity
        only appears in one of the validation and test set based on smiles IDs"""
        self.assertEqual(
            len(self.overlaps_val_test_ids),
            0,
            "Duplicate entities present in Validation and Test set based on IDs",
        )
114+
115+
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments