Skip to content

Commit 3ab3a7f

Browse files
authored
Trainer auto wait pserver ports (#13341)
* trainer auto wait pserver port ready * add file * fix docstring * add option to not wait * update api spec * clean * fix test hang
1 parent 7622234 commit 3ab3a7f

File tree

5 files changed

+58
-4
lines changed

5 files changed

+58
-4
lines changed

paddle/fluid/API.spec

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ paddle.fluid.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], vara
5959
paddle.fluid.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
6060
paddle.fluid.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
6161
paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
62-
paddle.fluid.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
62+
paddle.fluid.DistributeTranspiler.get_trainer_program ArgSpec(args=['self', 'wait_port'], varargs=None, keywords=None, defaults=(True,))
6363
paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None))
6464
paddle.fluid.InferenceTranspiler.__init__
6565
paddle.fluid.InferenceTranspiler.transpile ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))
@@ -346,7 +346,7 @@ paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'con
346346
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
347347
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
348348
paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
349-
paddle.fluid.transpiler.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
349+
paddle.fluid.transpiler.DistributeTranspiler.get_trainer_program ArgSpec(args=['self', 'wait_port'], varargs=None, keywords=None, defaults=(True,))
350350
paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None))
351351
paddle.fluid.transpiler.InferenceTranspiler.__init__
352352
paddle.fluid.transpiler.InferenceTranspiler.transpile ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))

python/paddle/fluid/tests/unittests/test_dist_transpiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def get_trainer(self, config=None):
6262

6363
t = self._transpiler_instance(config)
6464

65-
trainer_main = t.get_trainer_program()
65+
trainer_main = t.get_trainer_program(wait_port=False)
6666
trainer_startup = fluid.default_startup_program()
6767

6868
assert (src.num_blocks == 1)

python/paddle/fluid/transpiler/details/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616

1717
from .program_utils import *
1818
from .ufind import *
19+
from .checkport import *
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import sys
16+
import time
17+
import socket
18+
from contextlib import closing
19+
20+
21+
def wait_server_ready(endpoints):
22+
"""
23+
Wait until parameter servers are ready, use connext_ex to detect
24+
port readiness.
25+
26+
Args:
27+
endpoints (list): endpoints string list, like:
28+
["127.0.0.1:8080", "127.0.0.1:8081"]
29+
30+
Examples:
31+
.. code-block:: python
32+
33+
wait_server_ready(["127.0.0.1:8080", "127.0.0.1:8081"])
34+
"""
35+
while True:
36+
all_ok = True
37+
for ep in endpoints:
38+
ip_port = ep.split(":")
39+
with closing(socket.socket(socket.AF_INET,
40+
socket.SOCK_STREAM)) as sock:
41+
sock.settimeout(2)
42+
result = sock.connect_ex((ip_port[0], int(ip_port[1])))
43+
if result != 0:
44+
all_ok = False
45+
if not all_ok:
46+
sys.stderr.write("pserver not ready, wait 3 sec to retry...\n")
47+
sys.stderr.flush()
48+
time.sleep(3)
49+
else:
50+
break

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def transpile(self,
381381
pserver_endpoints)
382382
self._split_table_grad_and_add_send_vars(program, pserver_endpoints)
383383

384-
def get_trainer_program(self):
384+
def get_trainer_program(self, wait_port=True):
385385
"""
386386
Get transpiled trainer side program.
387387
@@ -393,6 +393,9 @@ def get_trainer_program(self):
393393
delete_ops(self.origin_program.global_block(), self.optimize_ops)
394394
self.origin_program.__str__()
395395

396+
if wait_port:
397+
wait_server_ready(self.pserver_endpoints)
398+
396399
return self.origin_program
397400

398401
def _get_trainer_startup_program(self, recv_vars, eplist):

0 commit comments

Comments
 (0)