Skip to content

Commit 524c81e

Browse files
author
Yancey
authored
Merge pull request #11126 from Yancey1989/polish_test_listen_and_serv_op
speedup test_listen_and_serv_op
2 parents f7a6001 + 7f5eb9f commit 524c81e

File tree

2 files changed

+35
-36
lines changed

2 files changed

+35
-36
lines changed

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,7 @@ foreach(TEST_OP ${TEST_OPS})
4848
endforeach(TEST_OP)
4949
py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL)
5050
py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
51-
# tests that need to be done in fixed timeout
52-
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)
51+
# FIXME(Yancey1989): this test takes much longer on CUDAPlace
52+
# because it has to load the cuDNN libraries, so we use a longer timeout
53+
# to keep this unit test stable.
54+
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 30)

python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from op_test import OpTest
2424

2525

26-
def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):
26+
def run_pserver(use_cuda, sync_mode, ip, port, trainers, trainer_id):
2727
x = fluid.layers.data(name='x', shape=[1], dtype='float32')
2828
y_predict = fluid.layers.fc(input=x, size=1, act=None)
2929
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
@@ -39,15 +39,8 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):
3939
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
4040
exe = fluid.Executor(place)
4141

42-
port = os.getenv("PADDLE_INIT_PORT", port)
43-
pserver_ips = os.getenv("PADDLE_INIT_PSERVERS", ip) # ip,ip...
44-
eplist = []
45-
for ip in pserver_ips.split(","):
46-
eplist.append(':'.join([ip, port]))
47-
pserver_endpoints = ",".join(eplist) # ip:port,ip:port...
48-
trainers = int(os.getenv("TRAINERS", trainer_count))
49-
current_endpoint = os.getenv("POD_IP", ip) + ":" + port
50-
trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID", trainer_id))
42+
pserver_endpoints = ip + ":" + port
43+
current_endpoint = ip + ":" + port
5144
t = fluid.DistributeTranspiler()
5245
t.transpile(
5346
trainer_id,
@@ -62,47 +55,51 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):
6255

6356
class TestListenAndServOp(OpTest):
6457
def setUp(self):
65-
self.sleep_time = 5
58+
self.ps_timeout = 5
6659
self.ip = "127.0.0.1"
6760
self.port = "6173"
68-
self.trainer_count = 1
61+
self.trainers = 1
6962
self.trainer_id = 1
7063

71-
def _raise_signal(self, parent_pid, raised_signal):
72-
time.sleep(self.sleep_time)
73-
ps_command = subprocess.Popen(
74-
"ps -o pid --ppid %d --noheaders" % parent_pid,
75-
shell=True,
76-
stdout=subprocess.PIPE)
77-
ps_output = ps_command.stdout.read()
78-
retcode = ps_command.wait()
79-
assert retcode == 0, "ps command returned %d" % retcode
80-
81-
for pid_str in ps_output.split("\n")[:-1]:
82-
try:
83-
os.kill(int(pid_str), raised_signal)
84-
except Exception:
85-
continue
86-
8764
def _start_pserver(self, use_cuda, sync_mode):
8865
p = Process(
8966
target=run_pserver,
90-
args=(use_cuda, sync_mode, self.ip, self.port, self.trainer_count,
67+
args=(use_cuda, sync_mode, self.ip, self.port, self.trainers,
9168
self.trainer_id))
9269
p.start()
70+
return p.pid
71+
72+
def _wait_ps_ready(self, pid):
73+
retry_times = self.ps_timeout
74+
while True:
75+
assert retry_times >= 0, "wait ps ready failed"
76+
time.sleep(0.5)
77+
try:
78+
# the listen_and_serv_op creates a file containing the listening port
79+
# in the /tmp directory once it is ready to process all RPC calls.
80+
os.stat("/tmp/paddle.%d.port" % pid)
81+
return
82+
except os.error:
83+
retry_times -= 1
84+
85+
def test_rpc_interfaces(self):
86+
# TODO(Yancey1989): need to verify that the RPC interfaces work correctly.
87+
pass
9388

9489
def test_handle_signal_in_serv_op(self):
9590
# run pserver on CPU in sync mode
96-
self._start_pserver(False, True)
91+
pid = self._start_pserver(False, True)
92+
self._wait_ps_ready(pid)
9793

98-
# raise SIGINT to pserver
99-
self._raise_signal(os.getpid(), signal.SIGINT)
94+
# raise SIGTERM to pserver
95+
os.kill(pid, signal.SIGTERM)
10096

10197
# run pserver on CPU in async mode
102-
self._start_pserver(False, False)
98+
pid = self._start_pserver(False, False)
99+
self._wait_ps_ready(pid)
103100

104101
# raise SIGTERM to pserver
105-
self._raise_signal(os.getpid(), signal.SIGTERM)
102+
os.kill(pid, signal.SIGTERM)
106103

107104

108105
if __name__ == '__main__':

0 commit comments

Comments
 (0)