From fdf6f5ebee9fadf890f465bd8732d453a07e5915 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Sat, 10 May 2025 21:21:27 +0200
Subject: [PATCH 1/2] no need to catch exceptions

---
 chebai/preprocessing/datasets/base.py | 26 ++++++++++----------------
 1 file changed, 10 insertions(+), 16 deletions(-)

diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index 4cf6edf0..14319014 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -1149,24 +1149,18 @@ def load_processed_data(
 
         # If both kind and filename are given, use filename
         if kind is not None and filename is None:
-            try:
-                if self.use_inner_cross_validation and kind != "test":
-                    filename = self.processed_file_names_dict[
-                        f"fold_{self.fold_index}_{kind}"
-                    ]
-                else:
-                    data_df = self.dynamic_split_dfs[kind]
-                    return data_df.to_dict(orient="records")
-            except KeyError:
-                kind = f"{kind}"
+            if self.use_inner_cross_validation and kind != "test":
+                filename = self.processed_file_names_dict[
+                    f"fold_{self.fold_index}_{kind}"
+                ]
+            else:
+                data_df = self.dynamic_split_dfs[kind]
+                return data_df.to_dict(orient="records")
 
         # If filename is provided
-        try:
-            return torch.load(
-                os.path.join(self.processed_dir, filename), weights_only=False
-            )
-        except FileNotFoundError:
-            raise FileNotFoundError(f"File {filename} doesn't exist")
+        return torch.load(
+            os.path.join(self.processed_dir, filename), weights_only=False
+        )
 
     # ------------------------------ Phase: Raw Properties -----------------------------------
     @property

From 8c5dce43c09891f20de1bc648630cccb9f88337c Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Sat, 10 May 2025 21:24:25 +0200
Subject: [PATCH 2/2] fix: use `torch.load` instead of `load_processed_data`

---
 chebai/preprocessing/datasets/base.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index 14319014..577251db 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -1103,7 +1103,9 @@ def _retrieve_splits_from_csv(self) -> None:
         splits_df = pd.read_csv(self.splits_file_path)
 
         filename = self.processed_file_names_dict["data"]
-        data = self.load_processed_data(filename=filename)
+        data = torch.load(
+            os.path.join(self.processed_dir, filename), weights_only=False
+        )
         df_data = pd.DataFrame(data)
 
         train_ids = splits_df[splits_df["split"] == "train"]["id"]
@@ -1114,6 +1116,7 @@ def _retrieve_splits_from_csv(self) -> None:
         self._dynamic_df_val = df_data[df_data["ident"].isin(validation_ids)]
         self._dynamic_df_test = df_data[df_data["ident"].isin(test_ids)]
 
+    # ------------------------------ Phase: DataLoaders -----------------------------------
     def load_processed_data(
         self, kind: Optional[str] = None, filename: Optional[str] = None
     ) -> List[Dict[str, Any]]:
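
Note (not part of the patches): a minimal standalone sketch of the loading path these two commits converge on, assuming the processed file is a pickled list of record dicts and the splits CSV has "id" and "split" columns as in the hunks above. The concrete paths, the "validation"/"test" split labels, and every name outside the quoted hunks are placeholders, not verbatim chebai code.

    import os

    import pandas as pd
    import torch

    processed_dir = "data/processed"      # placeholder for self.processed_dir
    filename = "data.pt"                  # placeholder for processed_file_names_dict["data"]
    splits_file_path = "data/splits.csv"  # placeholder for self.splits_file_path

    # The processed file holds arbitrary Python objects (a list of dicts),
    # so weights_only=False is required when reading it with torch.load.
    data = torch.load(os.path.join(processed_dir, filename), weights_only=False)
    df_data = pd.DataFrame(data)

    # Re-create the train/validation/test split from the CSV, matching each
    # record's "ident" field against the ids listed for each split label.
    splits_df = pd.read_csv(splits_file_path)
    train_ids = splits_df[splits_df["split"] == "train"]["id"]
    validation_ids = splits_df[splits_df["split"] == "validation"]["id"]
    test_ids = splits_df[splits_df["split"] == "test"]["id"]

    df_train = df_data[df_data["ident"].isin(train_ids)]
    df_val = df_data[df_data["ident"].isin(validation_ids)]
    df_test = df_data[df_data["ident"].isin(test_ids)]

If the file is missing, torch.load already raises FileNotFoundError on its own, which is why the try/except that merely re-raised the same exception could be dropped in the first commit.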