Skip to content

Commit edcf04c

Browse files
authored
[cherry-pick] fix pickle between python 2 & 3 (#22620)
* cherry-pick #22555 test=release/1.7, test=develop * cherry-pick #22621 test=release/1.7, test=develop
1 parent c000f8a commit edcf04c

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

python/paddle/fluid/dygraph/checkpoint.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import collections
1919
from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase
2020
import pickle
21+
import six
2122
from . import learning_rate_scheduler
2223
import warnings
2324
from .. import core
@@ -88,7 +89,7 @@ def save_dygraph(state_dict, model_path):
8889
os.makedirs(dir_name)
8990

9091
with open(file_name, 'wb') as f:
91-
pickle.dump(model_dict, f)
92+
pickle.dump(model_dict, f, protocol=2)
9293

9394

9495
@dygraph_only
@@ -130,14 +131,16 @@ def load_dygraph(model_path, keep_name_table=False):
130131
params_file_path))
131132

132133
with open(params_file_path, 'rb') as f:
133-
para_dict = pickle.load(f)
134+
para_dict = pickle.load(f) if six.PY2 else pickle.load(
135+
f, encoding='latin1')
134136

135137
if not keep_name_table and "StructuredToParameterName@@" in para_dict:
136138
del para_dict["StructuredToParameterName@@"]
137139
opti_dict = None
138140
opti_file_path = model_path + ".pdopt"
139141
if os.path.exists(opti_file_path):
140142
with open(opti_file_path, 'rb') as f:
141-
opti_dict = pickle.load(f)
143+
opti_dict = pickle.load(f) if six.PY2 else pickle.load(
144+
f, encoding='latin1')
142145

143146
return para_dict, opti_dict

python/paddle/fluid/io.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,7 @@ def name_has_fc(var):
800800
var_temp = paddle.fluid.global_scope().find_var(each_var.name)
801801
assert var_temp != None, "can't not find var: " + each_var.name
802802
new_shape = (np.array(var_temp.get_tensor())).shape
803-
assert each_var.name in orig_para_shape, earch_var.name + "MUST in var list"
803+
assert each_var.name in orig_para_shape, each_var.name + "MUST in var list"
804804
orig_shape = orig_para_shape.get(each_var.name)
805805
if new_shape != orig_shape:
806806
raise RuntimeError(
@@ -1579,14 +1579,14 @@ def get_tensor(var):
15791579
parameter_list = list(filter(is_parameter, program.list_vars()))
15801580
param_dict = {p.name: get_tensor(p) for p in parameter_list}
15811581
with open(model_path + ".pdparams", 'wb') as f:
1582-
pickle.dump(param_dict, f)
1582+
pickle.dump(param_dict, f, protocol=2)
15831583

15841584
optimizer_var_list = list(
15851585
filter(is_belong_to_optimizer, program.list_vars()))
15861586

15871587
opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list}
15881588
with open(model_path + ".pdopt", 'wb') as f:
1589-
pickle.dump(opt_dict, f)
1589+
pickle.dump(opt_dict, f, protocol=2)
15901590

15911591
main_program = program.clone()
15921592
program.desc.flush()
@@ -1733,7 +1733,8 @@ def set_var(var, ndarray):
17331733
global_scope(),
17341734
executor._default_executor)
17351735
with open(parameter_file_name, 'rb') as f:
1736-
load_dict = pickle.load(f)
1736+
load_dict = pickle.load(f) if six.PY2 else pickle.load(
1737+
f, encoding='latin1')
17371738
for v in parameter_list:
17381739
assert v.name in load_dict, \
17391740
"Can not find [{}] in model file [{}]".format(
@@ -1753,7 +1754,8 @@ def set_var(var, ndarray):
17531754
optimizer_var_list, global_scope(), executor._default_executor)
17541755

17551756
with open(opt_file_name, 'rb') as f:
1756-
load_dict = pickle.load(f)
1757+
load_dict = pickle.load(f) if six.PY2 else pickle.load(
1758+
f, encoding='latin1')
17571759
for v in optimizer_var_list:
17581760
assert v.name in load_dict, \
17591761
"Can not find [{}] in model file [{}]".format(
@@ -1877,12 +1879,14 @@ def clone_var_to_block(block, var):
18771879
"Parameter file [{}] not exits".format(parameter_file_name)
18781880

18791881
with open(parameter_file_name, 'rb') as f:
1880-
para_dict = pickle.load(f)
1882+
para_dict = pickle.load(f) if six.PY2 else pickle.load(
1883+
f, encoding='latin1')
18811884

18821885
opt_file_name = model_prefix + ".pdopt"
18831886
if os.path.exists(opt_file_name):
18841887
with open(opt_file_name, 'rb') as f:
1885-
opti_dict = pickle.load(f)
1888+
opti_dict = pickle.load(f) if six.PY2 else pickle.load(
1889+
f, encoding='latin1')
18861890

18871891
para_dict.update(opti_dict)
18881892

0 commit comments

Comments
 (0)