Skip to content

Commit 648c675

Browse files
authored
Merge pull request #26 from ChEB-AI/fix-pandas-serialisation
replace pickle.load with pd.read_pickle for raw files
2 parents: dd45138 + 04d390b; commit 648c675

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

chebai/loss/bce_weighted.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@ def set_pos_weight(self, input):
3535
):
3636
complete_data = pd.concat(
3737
[
38-
pickle.load(
38+
pd.read_pickle(
3939
open(
4040
os.path.join(
4141
self.data_extractor.raw_dir,

chebai/preprocessing/datasets/chebi.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -192,7 +192,7 @@ def graph_to_raw_dataset(self, g, split_name=None):
192192
return data
193193

194194
def save_raw(self, data: pd.DataFrame, filename: str):
195-
pickle.dump(data, open(os.path.join(self.raw_dir, filename), "wb"))
195+
pd.to_pickle(data, open(os.path.join(self.raw_dir, filename), "wb"))
196196

197197
def _load_dict(self, input_file_path):
198198
"""
@@ -205,7 +205,7 @@ def _load_dict(self, input_file_path):
205205
dict: The dictionary, keys are `features`, `labels` and `ident`.
206206
"""
207207
with open(input_file_path, "rb") as input_file:
208-
df = pickle.load(input_file)
208+
df = pd.read_pickle(input_file)
209209
if self.single_class is not None:
210210
single_cls_index = list(df.columns).index(int(self.single_class))
211211
for row in df.values:
@@ -218,7 +218,7 @@ def _load_dict(self, input_file_path):
218218
@staticmethod
219219
def _get_data_size(input_file_path):
220220
with open(input_file_path, "rb") as f:
221-
return len(pickle.load(f))
221+
return len(pd.read_pickle(f))
222222

223223
def _setup_pruned_test_set(self):
224224
"""Create test set with same leaf nodes, but use classes that appear in train set"""
@@ -468,7 +468,7 @@ def prepare_data(self, *args, **kwargs):
468468
with open(
469469
os.path.join(self.raw_dir, self.raw_file_names_dict["test"]), "rb"
470470
) as input_file:
471-
test_df = pickle.load(input_file)
471+
test_df = pd.read_pickle(input_file)
472472
# create train/val split based on test set
473473
chebi_path = self._load_chebi(
474474
self.chebi_version_train

0 commit comments

Comments
 (0)