Skip to content

Commit ad4d965

Browse files
author
Yan Xu
authored
Merge pull request #13250 from Yancey1989/fix_dist_base
fix parallel run dist unit test
2 parents e27a1a6 + 0a71d58 commit ad4d965

File tree

2 files changed

+12 -5 lines changed

2 files changed

+12 -5 lines changed

python/paddle/dataset/image.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -104,7 +104,7 @@ def batch_images_from_tar(data_file,
104104
pickle.dump(
105105
output,
106106
open('%s/batch_%d' % (out_path, file_id), 'wb'),
107-
protocol=pickle.HIGHEST_PROTOCOL)
107+
protocol=2)
108108
file_id += 1
109109
data = []
110110
labels = []
@@ -113,9 +113,7 @@ def batch_images_from_tar(data_file,
113113
output['label'] = labels
114114
output['data'] = data
115115
pickle.dump(
116-
output,
117-
open('%s/batch_%d' % (out_path, file_id), 'wb'),
118-
protocol=pickle.HIGHEST_PROTOCOL)
116+
output, open('%s/batch_%d' % (out_path, file_id), 'wb'), protocol=2)
119117

120118
with open(meta_file, 'a') as meta:
121119
for file in os.listdir(out_path):

python/paddle/fluid/tests/unittests/test_dist_base.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,7 @@ def run_pserver(self, args):
5555
pserver_prog = t.get_pserver_program(args.current_endpoint)
5656
startup_prog = t.get_startup_program(args.current_endpoint,
5757
pserver_prog)
58+
5859
place = fluid.CPUPlace()
5960
exe = fluid.Executor(place)
6061
exe.run(startup_prog)
@@ -147,6 +148,8 @@ def runtime_main(test_class):
147148

148149

149150
import paddle.compat as cpt
151+
import socket
152+
from contextlib import closing
150153

151154

152155
class TestDistBase(unittest.TestCase):
@@ -156,13 +159,19 @@ def _setup_config(self):
156159
def setUp(self):
157160
self._trainers = 2
158161
self._pservers = 2
159-
self._ps_endpoints = "127.0.0.1:9123,127.0.0.1:9124"
162+
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
163+
self._find_free_port(), self._find_free_port())
160164
self._python_interp = "python"
161165
self._sync_mode = True
162166
self._mem_opt = False
163167
self._use_reduce = False
164168
self._setup_config()
165169

170+
def _find_free_port(self):
171+
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
172+
s.bind(('', 0))
173+
return s.getsockname()[1]
174+
166175
def start_pserver(self, model_file, check_error_log):
167176
ps0_ep, ps1_ep = self._ps_endpoints.split(",")
168177
ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist"

0 commit comments

Comments
 (0)