Skip to content

Commit 5c48e4d

Browse files
committed
Added directory for augmented directory for splitting
1 parent 5916420 commit 5c48e4d

File tree

1 file changed

+33
-28
lines changed

1 file changed

+33
-28
lines changed

chebai/preprocessing/datasets/chebi.py

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -449,33 +449,34 @@ def setup_processed(self) -> None:
449449
# )
450450

451451
# Transform the processed data into encoded data
452-
processed_name = self.processed_file_names_dict["data"]
453-
if not os.path.isfile(os.path.join(self.processed_dir, processed_name)):
454-
print(
455-
f"Missing encoded data related to version {self.chebi_version}, transform processed data into encoded data:",
456-
processed_name,
457-
)
458-
torch.save(
459-
self._load_data_from_file(
460-
os.path.join(
461-
self.processed_dir_main,
462-
self.raw_file_names_dict["data"],
463-
)
464-
),
465-
os.path.join(self.processed_dir, processed_name),
466-
)
467-
# Transform the data related to "chebi_version_train" to encoded data, if it doesn't exist
468-
if self.chebi_version_train is not None and not os.path.isfile(
469-
os.path.join(
470-
self._chebi_version_train_obj.processed_dir,
471-
self._chebi_version_train_obj.raw_file_names_dict["data"],
472-
)
473-
):
474-
print(
475-
f"Missing encoded data related to train version: {self.chebi_version_train}"
476-
)
477-
print("Call the setup method related to it")
478-
self._chebi_version_train_obj.setup()
452+
if not self.aug_data:
453+
processed_name = self.processed_file_names_dict["data"]
454+
if not os.path.isfile(os.path.join(self.processed_dir, processed_name)):
455+
print(
456+
f"Missing encoded data related to version {self.chebi_version}, transform processed data into encoded data:",
457+
processed_name,
458+
)
459+
torch.save(
460+
self._load_data_from_file(
461+
os.path.join(
462+
self.processed_dir_main,
463+
self.raw_file_names_dict["data"],
464+
)
465+
),
466+
os.path.join(self.processed_dir, processed_name),
467+
)
468+
# Transform the data related to "chebi_version_train" to encoded data, if it doesn't exist
469+
if self.chebi_version_train is not None and not os.path.isfile(
470+
os.path.join(
471+
self._chebi_version_train_obj.processed_dir,
472+
self._chebi_version_train_obj.raw_file_names_dict["data"],
473+
)
474+
):
475+
print(
476+
f"Missing encoded data related to train version: {self.chebi_version_train}"
477+
)
478+
print("Call the setup method related to it")
479+
self._chebi_version_train_obj.setup()
479480

480481

481482

@@ -810,9 +811,13 @@ def _generate_dynamic_splits(self) -> None:
810811
"""
811812
print("Generate dynamic splits...")
812813
# Load encoded data derived from "chebi_version"
814+
# Determine the directory for loading encoded data based on the aug_data flag
815+
data_dir = self.augmented_dir_main if self.aug_data else self.processed_dir
816+
813817
try:
814818
filename = self.processed_file_names_dict["data"]
815-
data_chebi_version = torch.load(os.path.join(self.processed_dir, filename))
819+
print("Directory:",os.path.join(data_dir, filename))
820+
data_chebi_version = torch.load(os.path.join(data_dir, filename))
816821
except FileNotFoundError:
817822
raise FileNotFoundError(
818823
f"File data.pt doesn't exists. "

0 commit comments

Comments
 (0)