@@ -449,33 +449,34 @@ def setup_processed(self) -> None:
449449 # )
450450
451451 # Transform the processed data into encoded data
452- processed_name = self .processed_file_names_dict ["data" ]
453- if not os .path .isfile (os .path .join (self .processed_dir , processed_name )):
454- print (
455- f"Missing encoded data related to version { self .chebi_version } , transform processed data into encoded data:" ,
456- processed_name ,
457- )
458- torch .save (
459- self ._load_data_from_file (
460- os .path .join (
461- self .processed_dir_main ,
462- self .raw_file_names_dict ["data" ],
463- )
464- ),
465- os .path .join (self .processed_dir , processed_name ),
466- )
467- # Transform the data related to "chebi_version_train" to encoded data, if it doesn't exist
468- if self .chebi_version_train is not None and not os .path .isfile (
469- os .path .join (
470- self ._chebi_version_train_obj .processed_dir ,
471- self ._chebi_version_train_obj .raw_file_names_dict ["data" ],
472- )
473- ):
474- print (
475- f"Missing encoded data related to train version: { self .chebi_version_train } "
476- )
477- print ("Call the setup method related to it" )
478- self ._chebi_version_train_obj .setup ()
452+ if not self .aug_data :
453+ processed_name = self .processed_file_names_dict ["data" ]
454+ if not os .path .isfile (os .path .join (self .processed_dir , processed_name )):
455+ print (
456+ f"Missing encoded data related to version { self .chebi_version } , transform processed data into encoded data:" ,
457+ processed_name ,
458+ )
459+ torch .save (
460+ self ._load_data_from_file (
461+ os .path .join (
462+ self .processed_dir_main ,
463+ self .raw_file_names_dict ["data" ],
464+ )
465+ ),
466+ os .path .join (self .processed_dir , processed_name ),
467+ )
468+ # Transform the data related to "chebi_version_train" to encoded data, if it doesn't exist
469+ if self .chebi_version_train is not None and not os .path .isfile (
470+ os .path .join (
471+ self ._chebi_version_train_obj .processed_dir ,
472+ self ._chebi_version_train_obj .raw_file_names_dict ["data" ],
473+ )
474+ ):
475+ print (
476+ f"Missing encoded data related to train version: { self .chebi_version_train } "
477+ )
478+ print ("Call the setup method related to it" )
479+ self ._chebi_version_train_obj .setup ()
479480
480481
481482
@@ -810,9 +811,13 @@ def _generate_dynamic_splits(self) -> None:
810811 """
811812 print ("Generate dynamic splits..." )
812813 # Load encoded data derived from "chebi_version"
814+ # Determine the directory for loading encoded data based on the aug_data flag
815+ data_dir = self .augmented_dir_main if self .aug_data else self .processed_dir
816+
813817 try :
814818 filename = self .processed_file_names_dict ["data" ]
815- data_chebi_version = torch .load (os .path .join (self .processed_dir , filename ))
819+ print ("Directory:" ,os .path .join (data_dir , filename ))
820+ data_chebi_version = torch .load (os .path .join (data_dir , filename ))
816821 except FileNotFoundError :
817822 raise FileNotFoundError (
818823 f"File data.pt doesn't exists. "
0 commit comments