11import unittest
22import os
3- from pathlib import Path
43import torch
5- import configparser
4+ import yaml
65
76
87class TestChebiData (unittest .TestCase ):
@@ -15,17 +14,18 @@ def setUpClass(cls) -> None:
1514 @classmethod
1615 def getChebiDataConfig (cls ):
1716 """Import the respective class and instantiate with given version from the config"""
18- config = configparser . ConfigParser ()
19- config_file_path = Path ( os . path . join ( os . getcwd () , "tests/config_chebi_data.ini" ))
20- config . read ( config_file_path )
17+ CONFIG_FILE_NAME = "chebi50.yml"
18+ with open ( f"configS/data/ { CONFIG_FILE_NAME } " , "r" ) as yaml_file :
19+ config = yaml . safe_load ( yaml_file )
2120
22- class_name = config . get ( 'ChebiData' , 'chebi_class_name' )
23- version_number = config .get ('ChebiData' , 'version_number' )
21+ class_path = config [ "class_path" ]
22+ init_args = config .get ("init_args" , {} )
2423
25- module = __import__ ('chebai.preprocessing.datasets.chebi' , fromlist = [class_name ])
24+ module , class_name = class_path .rsplit ("." , 1 )
25+ module = __import__ (module , fromlist = [class_name ])
2626 class_ = getattr (module , class_name )
2727
28- cls .chebi_class = class_ (chebi_version = version_number )
28+ cls .chebi_class = class_ (** init_args )
2929
3030 @classmethod
3131 def getDataSplitsOverlaps (cls ):
@@ -58,7 +58,9 @@ def getDataSplitsOverlaps(cls):
5858 # val_smiles_ids.append(test_smiles_ids[0])
5959
6060 cls .overlaps_train_val_ids = cls .get_overlaps (train_smiles_ids , val_smiles_ids )
61- cls .overlaps_train_test_ids = cls .get_overlaps (train_smiles_ids , test_smiles_ids )
61+ cls .overlaps_train_test_ids = cls .get_overlaps (
62+ train_smiles_ids , test_smiles_ids
63+ )
6264 cls .overlaps_val_test_ids = cls .get_overlaps (val_smiles_ids , test_smiles_ids )
6365
6466 @staticmethod
@@ -82,36 +84,62 @@ def get_overlaps(list_1, list_2):
8284 @unittest .expectedFailure
8385 def test_train_val_overlap_based_on_smiles (self ):
8486 """Check that train-val splits are performed correctly i.e.every entity
85- only appears in one of the train and validation set based on smiles tokens/features"""
86- self .assertEqual (len (self .overlaps_train_val ), 0 , "Duplicate entities present in Train and Validation set based on SMILES" )
87+ only appears in one of the train and validation set based on smiles tokens/features
88+ """
89+ self .assertEqual (
90+ len (self .overlaps_train_val ),
91+ 0 ,
92+ "Duplicate entities present in Train and Validation set based on SMILES" ,
93+ )
8794
8895 @unittest .expectedFailure
8996 def test_train_test_overlap_based_on_smiles (self ):
9097 """Check that train-test splits are performed correctly i.e.every entity
9198 only appears in one of the train and test set based on smiles tokens/features"""
92- self .assertEqual (len (self .overlaps_train_test ), 0 , "Duplicate entities present in Train and Test set based on SMILES" )
99+ self .assertEqual (
100+ len (self .overlaps_train_test ),
101+ 0 ,
102+ "Duplicate entities present in Train and Test set based on SMILES" ,
103+ )
93104
94105 @unittest .expectedFailure
95106 def test_val_test_overlap_based_on_smiles (self ):
96107 """Check that val-test splits are performed correctly i.e.every entity
97- only appears in one of the validation and test set based on smiles tokens/features"""
98- self .assertEqual (len (self .overlaps_val_test ), 0 , "Duplicate entities present in Validation and Test set based on SMILES" )
108+ only appears in one of the validation and test set based on smiles tokens/features
109+ """
110+ self .assertEqual (
111+ len (self .overlaps_val_test ),
112+ 0 ,
113+ "Duplicate entities present in Validation and Test set based on SMILES" ,
114+ )
99115
100116 def test_train_val_overlap_based_on_ids (self ):
101117 """Check that train-val splits are performed correctly i.e.every entity
102118 only appears in one of the train and validation set based on smiles IDs"""
103- self .assertEqual (len (self .overlaps_train_val_ids ), 0 , "Duplicate entities present in Train and Validation set based on IDs" )
119+ self .assertEqual (
120+ len (self .overlaps_train_val_ids ),
121+ 0 ,
122+ "Duplicate entities present in Train and Validation set based on IDs" ,
123+ )
104124
105125 def test_train_test_overlap_based_on_ids (self ):
106126 """Check that train-test splits are performed correctly i.e.every entity
107127 only appears in one of the train and test set based on smiles IDs"""
108- self .assertEqual (len (self .overlaps_train_test_ids ), 0 , "Duplicate entities present in Train and Test set based on IDs" )
128+ self .assertEqual (
129+ len (self .overlaps_train_test_ids ),
130+ 0 ,
131+ "Duplicate entities present in Train and Test set based on IDs" ,
132+ )
109133
110134 def test_val_test_overlap_based_on_ids (self ):
111135 """Check that val-test splits are performed correctly i.e.every entity
112136 only appears in one of the validation and test set based on smiles IDs"""
113- self .assertEqual (len (self .overlaps_val_test_ids ), 0 , "Duplicate entities present in Validation and Test set based on IDs" )
137+ self .assertEqual (
138+ len (self .overlaps_val_test_ids ),
139+ 0 ,
140+ "Duplicate entities present in Validation and Test set based on IDs" ,
141+ )
114142
115143
116- if __name__ == ' __main__' :
117- unittest .main ()
144+ if __name__ == " __main__" :
145+ unittest .main ()
0 commit comments