Skip to content

Commit 5c5c991

Browse files
committed
separate load_processed_data into file-based and kind-based loading
1 parent bc5b5ec commit 5c5c991

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,8 +1103,8 @@ def _retrieve_splits_from_csv(self) -> None:
11031103
splits_df = pd.read_csv(self.splits_file_path)
11041104

11051105
filename = self.processed_file_names_dict["data"]
1106-
data = torch.load(
1107-
os.path.join(self.processed_dir, filename), weights_only=False
1106+
data = self.load_processed_data_from_file(
1107+
os.path.join(self.processed_dir, filename)
11081108
)
11091109
df_data = pd.DataFrame(data)
11101110

@@ -1161,9 +1161,10 @@ def load_processed_data(
11611161
return data_df.to_dict(orient="records")
11621162

11631163
# If filename is provided
1164-
return torch.load(
1165-
os.path.join(self.processed_dir, filename), weights_only=False
1166-
)
1164+
return self.load_processed_data_from_file(filename)
1165+
1166+
def load_processed_data_from_file(self, filename):
1167+
return torch.load(os.path.join(filename), weights_only=False)
11671168

11681169
# ------------------------------ Phase: Raw Properties -----------------------------------
11691170
@property

chebai/preprocessing/datasets/chebi.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,8 +401,8 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
401401
"""
402402
try:
403403
filename = self.processed_file_names_dict["data"]
404-
data_chebi_version = torch.load(
405-
os.path.join(self.processed_dir, filename), weights_only=False
404+
data_chebi_version = self.load_processed_data_from_file(
405+
os.path.join(self.processed_dir, filename)
406406
)
407407
except FileNotFoundError:
408408
raise FileNotFoundError(

0 commit comments

Comments
 (0)