Skip to content

Commit a366733

Browse files
author
sfluegel
committed
clean up chebi setup
1 parent 3915e32 commit a366733

File tree

1 file changed

+7
-12
lines changed

1 file changed

+7
-12
lines changed

tests/testChebiData.py

Lines changed: 7 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -8,14 +8,15 @@ class TestChebiData(unittest.TestCase):
88

99
@classmethod
1010
def setUpClass(cls) -> None:
11-
cls.getChebiDataConfig()
1211
cls.getDataSplitsOverlaps()
1312

1413
@classmethod
1514
def getChebiDataConfig(cls):
1615
"""Import the respective class and instantiate with given version from the config"""
1716
CONFIG_FILE_NAME = "chebi50.yml"
18-
with open(f"configS/data/{CONFIG_FILE_NAME}", "r") as yaml_file:
17+
with open(
18+
os.path.join("configs", "data", f"{CONFIG_FILE_NAME}"), "r"
19+
) as yaml_file:
1920
config = yaml.safe_load(yaml_file)
2021

2122
class_path = config["class_path"]
@@ -25,12 +26,14 @@ def getChebiDataConfig(cls):
2526
module = __import__(module, fromlist=[class_name])
2627
class_ = getattr(module, class_name)
2728

28-
cls.chebi_class = class_(**init_args)
29+
return class_(**init_args)
2930

3031
@classmethod
3132
def getDataSplitsOverlaps(cls):
3233
"""Get the overlap between data splits"""
33-
processed_path = os.path.join(os.getcwd(), cls.chebi_class.processed_dir)
34+
processed_path = os.path.join(
35+
os.getcwd(), cls.getChebiDataConfig().processed_dir
36+
)
3437
print(f"Checking Data from - {processed_path}")
3538

3639
train_set = torch.load(os.path.join(processed_path, "train.pt"))
@@ -43,20 +46,12 @@ def getDataSplitsOverlaps(cls):
4346

4447
# ----- Get the overlap between data splits based on smiles tokens/features -----
4548

46-
# train_smiles.append(val_smiles[0])
47-
# train_smiles.append(test_smiles[0])
48-
# val_smiles.append(test_smiles[0])
49-
5049
cls.overlaps_train_val = cls.get_overlaps(train_smiles, val_smiles)
5150
cls.overlaps_train_test = cls.get_overlaps(train_smiles, test_smiles)
5251
cls.overlaps_val_test = cls.get_overlaps(val_smiles, test_smiles)
5352

5453
# ----- Get the overlap between data splits based on IDs -----
5554

56-
# train_smiles_ids.append(val_smiles_ids[0])
57-
# train_smiles_ids.append(test_smiles_ids[0])
58-
# val_smiles_ids.append(test_smiles_ids[0])
59-
6055
cls.overlaps_train_val_ids = cls.get_overlaps(train_smiles_ids, val_smiles_ids)
6156
cls.overlaps_train_test_ids = cls.get_overlaps(
6257
train_smiles_ids, test_smiles_ids

0 commit comments

Comments
 (0)