Skip to content

Commit f2a205c

Browse files
committed
add test_pserver_run_empty_optimize_block
1 parent bf97648 commit f2a205c

File tree

2 files changed

+119
-0
lines changed

2 files changed

+119
-0
lines changed

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ if(NOT WITH_DISTRIBUTE)
1515
list(REMOVE_ITEM TEST_OPS test_dist_transpiler)
1616
list(REMOVE_ITEM TEST_OPS test_simple_dist_transpiler)
1717
list(REMOVE_ITEM TEST_OPS test_listen_and_serv_op)
18+
list(REMOVE_ITEM TEST_OPS test_pserver_run_empty_optimize_block)
1819
LIST(REMOVE_ITEM TEST_OPS test_dist_mnist)
1920
LIST(REMOVE_ITEM TEST_OPS test_dist_word2vec)
2021
endif(NOT WITH_DISTRIBUTE)
@@ -74,6 +75,7 @@ py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=$
7475
if(WITH_DISTRIBUTE)
7576
py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
7677
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)
78+
set_tests_properties(test_pserver_run_empty_optimize_block PROPERTIES TIMEOUT 20)
7779
if(NOT APPLE)
7880
set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 200)
7981
set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200)
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import paddle
18+
import paddle.fluid as fluid
19+
import os
20+
import signal
21+
import subprocess
22+
import time
23+
import unittest
24+
from multiprocessing import Process
25+
from op_test import OpTest
26+
27+
28+
def run_pserver(use_cuda, sync_mode, ip, port, trainers, trainer_id):
    """Build a tiny regression net, transpile it, and run one pserver program.

    Two pserver endpoints are configured (``ip:port+1`` and ``ip:port``).
    The program generated for the second endpoint (``ip:port``) is checked
    to have an empty optimize block and is then executed.

    Args:
        use_cuda: run on ``fluid.CUDAPlace(0)`` when true, otherwise on CPU.
        sync_mode: forwarded unchanged to ``DistributeTranspiler.transpile``.
        ip: address shared by both pserver endpoints.
        port: port of the endpoint that is run, given as a string; the other
            endpoint uses ``int(port) + 1``.
        trainers: forwarded unchanged to ``transpile``.
        trainer_id: forwarded unchanged to ``transpile``.

    Raises:
        AssertionError: if the pserver program for ``ip:port`` does not have
            exactly two blocks, or if its second block contains any op.
    """
    x = fluid.layers.data(name='x', shape=[1], dtype='float32')
    y_predict = fluid.layers.fc(input=x, size=1, act=None, bias_attr=False)
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')

    # loss function
    cost = fluid.layers.square_error_cost(input=y_predict, label=y)
    avg_cost = fluid.layers.mean(cost)

    # optimizer
    sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
    sgd_optimizer.minimize(avg_cost)

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    ps1 = ip + ":" + str(int(port) + 1)
    ps2 = ip + ":" + port
    pserver_endpoints = ps1 + "," + ps2

    config = fluid.DistributeTranspilerConfig()
    # NOTE(review): with slicing disabled the single fc weight presumably
    # stays whole and is assigned to only one endpoint -- confirm.
    config.slice_var_up = False
    t = fluid.DistributeTranspiler(config=config)
    t.transpile(
        trainer_id,
        pservers=pserver_endpoints,
        trainers=trainers,
        sync_mode=sync_mode)
    pserver_prog = t.get_pserver_program(ps2)

    # The second pserver has no parameter, so its optimize block must be
    # empty.  These must be real comparisons: ``assert (a, b)`` tests a
    # non-empty tuple, which is always true and therefore checks nothing.
    assert len(pserver_prog.blocks) == 2, \
        "expected 2 blocks, got %d" % len(pserver_prog.blocks)
    assert len(pserver_prog.blocks[1].ops) == 0, \
        "expected an empty optimize block, got %d ops" % len(
            pserver_prog.blocks[1].ops)

    pserver_startup = t.get_startup_program(ps2, pserver_prog)
    exe.run(pserver_startup)
    # NOTE(review): presumably this call serves until the process receives a
    # signal -- confirm against listen_and_serv_op.
    exe.run(pserver_prog)
65+
66+
67+
class TestListenAndServOp(OpTest):
    """Start a pserver in a child process and stop it again with a signal."""

    def setUp(self):
        # Endpoint handed to run_pserver.
        self.ip = "127.0.0.1"
        self.port = "0"
        # Trainer settings forwarded to the transpiler.
        self.trainers = 1
        self.trainer_id = 0
        # Time budget consumed by _wait_ps_ready while polling.
        self.ps_timeout = 5

    def _start_pserver(self, use_cuda, sync_mode):
        """Launch run_pserver in a daemon child process and return it."""
        pserver_args = (use_cuda, sync_mode, self.ip, self.port,
                        self.trainers, self.trainer_id)
        server = Process(target=run_pserver, args=pserver_args)
        server.daemon = True
        server.start()
        return server

    def _wait_ps_ready(self, pid):
        """Poll until the port file for process ``pid`` shows up.

        Fails with an AssertionError once the ``ps_timeout`` budget has
        been used up without the file appearing.
        """
        remaining = self.ps_timeout
        poll_interval = 0.5
        while True:
            assert remaining >= 0, "wait ps ready failed"
            time.sleep(poll_interval)
            try:
                # The listen_and_serv_op touches a file holding the listen
                # port under /tmp once it is ready to process RPC calls.
                os.stat("/tmp/paddle.%d.port" % pid)
            except os.error:
                remaining -= poll_interval
            else:
                return

    def test_handle_signal_in_serv_op(self):
        """Run the pserver on CPU in both modes and signal it to stop."""
        scenarios = (
            (True, signal.SIGINT),  # sync mode, stopped with SIGINT
            (False, signal.SIGTERM),  # async mode, stopped with SIGTERM
        )
        for sync_mode, stop_signal in scenarios:
            server = self._start_pserver(False, sync_mode)
            self._wait_ps_ready(server.pid)

            # Deliver the signal and wait for the child to exit.
            os.kill(server.pid, stop_signal)
            server.join()
114+
115+
116+
if __name__ == '__main__':
117+
unittest.main()

0 commit comments

Comments
 (0)