diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index dfa0f999..c703ae1d 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -1045,9 +1045,7 @@ def _retrieve_splits_from_csv(self) -> None: splits_df = pd.read_csv(self.splits_file_path) filename = self.processed_file_names_dict["data"] - data = torch.load( - os.path.join(self.processed_dir, filename), weights_only=False - ) + data = self.load_processed_data(filename=filename) df_data = pd.DataFrame(data) train_ids = splits_df[splits_df["split"] == "train"]["id"]