Skip to content

Commit 92d60f1

Browse files
authored
Merge pull request #90 from ChEB-AI/fix/load_processed_data
Fix for recursive calls for `load_processed_data` method
2 parents 94f6710 + 8c5dce4 commit 92d60f1

File tree

1 file changed

+14
-17
lines changed
  • chebai/preprocessing/datasets

1 file changed

+14
-17
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,7 +1103,9 @@ def _retrieve_splits_from_csv(self) -> None:
11031103
splits_df = pd.read_csv(self.splits_file_path)
11041104

11051105
filename = self.processed_file_names_dict["data"]
1106-
data = self.load_processed_data(filename=filename)
1106+
data = torch.load(
1107+
os.path.join(self.processed_dir, filename), weights_only=False
1108+
)
11071109
df_data = pd.DataFrame(data)
11081110

11091111
train_ids = splits_df[splits_df["split"] == "train"]["id"]
@@ -1114,6 +1116,7 @@ def _retrieve_splits_from_csv(self) -> None:
11141116
self._dynamic_df_val = df_data[df_data["ident"].isin(validation_ids)]
11151117
self._dynamic_df_test = df_data[df_data["ident"].isin(test_ids)]
11161118

1119+
# ------------------------------ Phase: DataLoaders -----------------------------------
11171120
def load_processed_data(
11181121
self, kind: Optional[str] = None, filename: Optional[str] = None
11191122
) -> List[Dict[str, Any]]:
@@ -1149,24 +1152,18 @@ def load_processed_data(
11491152

11501153
# If both kind and filename are given, use filename
11511154
if kind is not None and filename is None:
1152-
try:
1153-
if self.use_inner_cross_validation and kind != "test":
1154-
filename = self.processed_file_names_dict[
1155-
f"fold_{self.fold_index}_{kind}"
1156-
]
1157-
else:
1158-
data_df = self.dynamic_split_dfs[kind]
1159-
return data_df.to_dict(orient="records")
1160-
except KeyError:
1161-
kind = f"{kind}"
1155+
if self.use_inner_cross_validation and kind != "test":
1156+
filename = self.processed_file_names_dict[
1157+
f"fold_{self.fold_index}_{kind}"
1158+
]
1159+
else:
1160+
data_df = self.dynamic_split_dfs[kind]
1161+
return data_df.to_dict(orient="records")
11621162

11631163
# If filename is provided
1164-
try:
1165-
return torch.load(
1166-
os.path.join(self.processed_dir, filename), weights_only=False
1167-
)
1168-
except FileNotFoundError:
1169-
raise FileNotFoundError(f"File {filename} doesn't exist")
1164+
return torch.load(
1165+
os.path.join(self.processed_dir, filename), weights_only=False
1166+
)
11701167

11711168
# ------------------------------ Phase: Raw Properties -----------------------------------
11721169
@property

0 commit comments

Comments
 (0)