diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 4cf6edf0..577251db 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -1103,7 +1103,9 @@ def _retrieve_splits_from_csv(self) -> None: splits_df = pd.read_csv(self.splits_file_path) filename = self.processed_file_names_dict["data"] - data = self.load_processed_data(filename=filename) + data = torch.load( + os.path.join(self.processed_dir, filename), weights_only=False + ) df_data = pd.DataFrame(data) train_ids = splits_df[splits_df["split"] == "train"]["id"] @@ -1114,6 +1116,7 @@ def _retrieve_splits_from_csv(self) -> None: self._dynamic_df_val = df_data[df_data["ident"].isin(validation_ids)] self._dynamic_df_test = df_data[df_data["ident"].isin(test_ids)] + # ------------------------------ Phase: DataLoaders ----------------------------------- def load_processed_data( self, kind: Optional[str] = None, filename: Optional[str] = None ) -> List[Dict[str, Any]]: @@ -1149,24 +1152,18 @@ def load_processed_data( # If both kind and filename are given, use filename if kind is not None and filename is None: - try: - if self.use_inner_cross_validation and kind != "test": - filename = self.processed_file_names_dict[ - f"fold_{self.fold_index}_{kind}" - ] - else: - data_df = self.dynamic_split_dfs[kind] - return data_df.to_dict(orient="records") - except KeyError: - kind = f"{kind}" + if self.use_inner_cross_validation and kind != "test": + filename = self.processed_file_names_dict[ + f"fold_{self.fold_index}_{kind}" + ] + else: + data_df = self.dynamic_split_dfs[kind] + return data_df.to_dict(orient="records") # If filename is provided - try: - return torch.load( - os.path.join(self.processed_dir, filename), weights_only=False - ) - except FileNotFoundError: - raise FileNotFoundError(f"File {filename} doesn't exist") + return torch.load( + os.path.join(self.processed_dir, filename), weights_only=False + ) # ------------------------------ Phase: Raw Properties ----------------------------------- @property