@@ -8,14 +8,15 @@ class TestChebiData(unittest.TestCase):
88
99 @classmethod
1010 def setUpClass (cls ) -> None :
11- cls .getChebiDataConfig ()
1211 cls .getDataSplitsOverlaps ()
1312
1413 @classmethod
1514 def getChebiDataConfig (cls ):
1615 """Import the respective class and instantiate with given version from the config"""
1716 CONFIG_FILE_NAME = "chebi50.yml"
18- with open (f"configS/data/{ CONFIG_FILE_NAME } " , "r" ) as yaml_file :
17+ with open (
18+ os .path .join ("configs" , "data" , f"{ CONFIG_FILE_NAME } " ), "r"
19+ ) as yaml_file :
1920 config = yaml .safe_load (yaml_file )
2021
2122 class_path = config ["class_path" ]
@@ -25,12 +26,14 @@ def getChebiDataConfig(cls):
2526 module = __import__ (module , fromlist = [class_name ])
2627 class_ = getattr (module , class_name )
2728
28- cls . chebi_class = class_ (** init_args )
29+ return class_ (** init_args )
2930
3031 @classmethod
3132 def getDataSplitsOverlaps (cls ):
3233 """Get the overlap between data splits"""
33- processed_path = os .path .join (os .getcwd (), cls .chebi_class .processed_dir )
34+ processed_path = os .path .join (
35+ os .getcwd (), cls .getChebiDataConfig ().processed_dir
36+ )
3437 print (f"Checking Data from - { processed_path } " )
3538
3639 train_set = torch .load (os .path .join (processed_path , "train.pt" ))
@@ -43,20 +46,12 @@ def getDataSplitsOverlaps(cls):
4346
4447 # ----- Get the overlap between data splits based on smiles tokens/features -----
4548
46- # train_smiles.append(val_smiles[0])
47- # train_smiles.append(test_smiles[0])
48- # val_smiles.append(test_smiles[0])
49-
5049 cls .overlaps_train_val = cls .get_overlaps (train_smiles , val_smiles )
5150 cls .overlaps_train_test = cls .get_overlaps (train_smiles , test_smiles )
5251 cls .overlaps_val_test = cls .get_overlaps (val_smiles , test_smiles )
5352
5453 # ----- Get the overlap between data splits based on IDs -----
5554
56- # train_smiles_ids.append(val_smiles_ids[0])
57- # train_smiles_ids.append(test_smiles_ids[0])
58- # val_smiles_ids.append(test_smiles_ids[0])
59-
6055 cls .overlaps_train_val_ids = cls .get_overlaps (train_smiles_ids , val_smiles_ids )
6156 cls .overlaps_train_test_ids = cls .get_overlaps (
6257 train_smiles_ids , test_smiles_ids
0 commit comments