Skip to content

Commit 7a64d48

Browse files
authored
fix test_save_load with pickle (#14410)
* fix test_save_load with pickle test=develop * fix test_save_load with pickle test=develop * fix test_save_load with pickle test=develop
1 parent d3aed98 commit 7a64d48

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

python/paddle/fluid/tests/unittests/dist_save_load.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from functools import reduce
2727

2828
import numpy as np
29+
import pickle
2930
import unittest
3031
import six
3132

@@ -166,7 +167,10 @@ def get_data():
166167
io.save_persistables(startup_exe, model_dir, trainer_prog)
167168

168169
var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor())
169-
print(np.ravel(var).tolist())
170+
if six.PY2:
171+
print(pickle.dumps(np.ravel(var).tolist()))
172+
else:
173+
sys.stdout.buffer.write(pickle.dumps(np.ravel(var).tolist()))
170174

171175

172176
if __name__ == "__main__":

python/paddle/fluid/tests/unittests/test_dist_save_load.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,14 @@ def check_with_place(self,
6565

6666
shutil.rmtree(model_dir)
6767

68-
local_np = np.array(eval(local_var[0]))
69-
train0_np = np.array(eval(tr0_var[0]))
70-
train1_np = np.array(eval(tr1_var[0]))
68+
local_np = np.array(local_var)
69+
train0_np = np.array(tr0_var)
70+
train1_np = np.array(tr1_var)
71+
7172
self.assertAlmostEqual(local_np.all(), train0_np.all(), delta=delta)
7273
self.assertAlmostEqual(local_np.all(), train1_np.all(), delta=delta)
7374
self.assertAlmostEqual(train0_np.all(), train1_np.all(), delta=delta)
7475

75-
@unittest.skip(reason="CI fail")
7676
def test_dist(self):
7777
need_envs = {
7878
"IS_DISTRIBUTED": '0',

0 commit comments

Comments
 (0)