Skip to content

Commit 5ea039b

Browse files
authored
Merge pull request #11470 from typhoonzero/fix_unitests
Fix dist ut
2 parents 916e863 + 40c631e commit 5ea039b

File tree

4 files changed

+59
-43
lines changed

4 files changed

+59
-43
lines changed

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
348348
};
349349

350350
void SignalHandler::StopAndExit(int signal_num) {
351-
VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit";
351+
// Do not use VLOG here for the device for printing maybe already released.
352+
// exit will release interal allocated resoureces.
352353
exit(0);
353354
}
354355

python/paddle/fluid/layers/io.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
from layer_function_generator import generate_layer_fn, templatedoc
2323

2424
__all__ = [
25-
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
26-
'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer',
27-
'random_data_generator', 'Preprocessor', 'load'
25+
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv',
26+
'open_recordio_file', 'open_files', 'read_file', 'shuffle', 'batch',
27+
'double_buffer', 'random_data_generator', 'Preprocessor', 'load'
2828
]
2929

3030

@@ -177,59 +177,51 @@ def complete_op(self):
177177
})
178178

179179

180-
def Send(endpoints, send_vars, get_vars=None):
180+
def Send(endpoints, send_vars, sync=True):
181181
"""
182-
Send layer
182+
Send variables to the server side, and get vars from server
183+
side when server have finished running server side program.
183184
184185
Args:
185-
endpoints: comma seperated IP:PORT pairs in the order
186+
endpoints (str): comma seperated IP:PORT pairs in the order
186187
of send_vars to send
187-
send_vars: vars to send
188-
get_vars: vars to get from server after send completes.
189-
190-
Send variables to the server side, and get vars from server
191-
side when server have finished running server side program.
188+
send_vars (list): variables to send to server
189+
sync (bool): whether to wait the request finish
190+
192191
"""
193192
assert (type(send_vars) == list)
194193

195194
epmap = endpoints.split(",")
196195
endpoints = list(set(epmap))
197196

198197
helper = LayerHelper("Send", **locals())
199-
if not get_vars:
200-
get_vars = []
201-
for s in send_vars:
202-
v = helper.create_tmp_variable(dtype=s.dtype, stop_gradient=True)
203-
get_vars.append(v)
204198
rpc_op_role_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
205199

206200
helper.append_op(
207201
type="send",
208202
inputs={"X": send_vars},
209-
outputs={"Out": get_vars},
210203
attrs={
211204
"endpoints": endpoints,
212205
"epmap": epmap,
213206
rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC
214207
})
208+
if sync:
209+
helper.append_op(type="send_barrier", attrs={"endpoints": endpoints})
215210

216-
return get_vars
217211

218-
219-
def Recv(endpoints, get_vars):
212+
def Recv(endpoints, get_vars, sync=True):
220213
"""
221-
Recv layer
214+
Receive variables from server side
222215
223216
Args:
224-
endpoints: comma seperated IP:PORT pairs in the order
217+
endpoints (str): comma seperated IP:PORT pairs in the order
225218
of send_vars to send
226-
send_vars: vars to send
227-
get_vars: vars to get from server after send completes.
219+
get_vars (list): vars to get from server after send completes.
220+
sync (bool): whether to wait the request finish
228221
229-
Send variables to the server side, and get vars from server
230-
side when server have finished running server side program.
222+
Returns:
223+
list: list of received variables
231224
"""
232-
assert (type(send_vars) == list)
233225
assert (type(get_vars) == list)
234226

235227
epmap = endpoints.split(",")
@@ -242,6 +234,9 @@ def Recv(endpoints, get_vars):
242234
outputs={"Out": get_vars},
243235
attrs={"endpoints": endpoints,
244236
"epmap": epmap})
237+
if sync:
238+
helper.append_op(type="fetch_barrier", attrs={"endpoints": endpoints})
239+
return get_vars
245240

246241

247242
def monkey_patch_reader_methods(reader):

python/paddle/fluid/tests/unittests/test_dist_train.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import time
1717
import unittest
1818
from multiprocessing import Process
19+
import signal
1920

2021
import numpy
2122

@@ -24,9 +25,6 @@
2425

2526

2627
class TestSendOp(unittest.TestCase):
27-
@unittest.skip(
28-
"This test is buggy. We cannot use time.sleep to sync processes, the connection may fail in unittest."
29-
)
3028
def test_send(self):
3129
# Run init_serv in a thread
3230
place = fluid.CPUPlace()
@@ -35,7 +33,9 @@ def test_send(self):
3533
p.daemon = True
3634
p.start()
3735

38-
time.sleep(10)
36+
self.ps_timeout = 5
37+
self._wait_ps_ready(p.pid)
38+
3939
with open("/tmp/paddle.%d.port" % p.pid, "r") as fn:
4040
selected_port = int(fn.readlines()[0])
4141
self.init_client(place, selected_port)
@@ -44,9 +44,23 @@ def test_send(self):
4444
self.assertTrue(numpy.allclose(self.local_out, self.dist_out))
4545

4646
# FIXME(typhoonzero): find a way to gracefully shutdown the server.
47-
os.system("kill -9 %d" % p.pid)
47+
os.kill(p.pid, signal.SIGKILL)
4848
p.join()
4949

50+
def _wait_ps_ready(self, pid):
51+
start_left_time = self.ps_timeout
52+
sleep_time = 0.5
53+
while True:
54+
assert start_left_time >= 0, "wait ps ready failed"
55+
time.sleep(sleep_time)
56+
try:
57+
# the listen_and_serv_op would touch a file which contains the listen port
58+
# on the /tmp directory until it was ready to process all the RPC call.
59+
os.stat("/tmp/paddle.%d.port" % pid)
60+
return
61+
except os.error:
62+
start_left_time -= sleep_time
63+
5064
def init_serv(self, place):
5165
main = fluid.Program()
5266

@@ -84,7 +98,10 @@ def init_client(self, place, port):
8498
dtype="float32",
8599
persistable=False,
86100
shape=[32, 32])
87-
o = layers.Send("127.0.0.1:%d" % port, [x], [get_var])
101+
fluid.initializer.Constant(value=2.3)(get_var, main.global_block())
102+
layers.Send("127.0.0.1:%d" % port, [x])
103+
o = layers.Recv("127.0.0.1:%d" % port, [get_var])
104+
88105
exe = fluid.Executor(place)
89106
self.dist_out = exe.run(main, fetch_list=o) # o is a list
90107

python/paddle/fluid/tests/unittests/test_listen_and_serv_op.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,18 @@ class TestListenAndServOp(OpTest):
5757
def setUp(self):
5858
self.ps_timeout = 5
5959
self.ip = "127.0.0.1"
60-
self.port = "6173"
60+
self.port = "0"
6161
self.trainers = 1
62-
self.trainer_id = 1
62+
self.trainer_id = 0
6363

6464
def _start_pserver(self, use_cuda, sync_mode):
6565
p = Process(
6666
target=run_pserver,
6767
args=(use_cuda, sync_mode, self.ip, self.port, self.trainers,
6868
self.trainer_id))
69+
p.daemon = True
6970
p.start()
70-
return p.pid
71+
return p
7172

7273
def _wait_ps_ready(self, pid):
7374
start_left_time = self.ps_timeout
@@ -89,18 +90,20 @@ def test_rpc_interfaces(self):
8990

9091
def test_handle_signal_in_serv_op(self):
9192
# run pserver on CPU in sync mode
92-
pid = self._start_pserver(False, True)
93-
self._wait_ps_ready(pid)
93+
p1 = self._start_pserver(False, True)
94+
self._wait_ps_ready(p1.pid)
9495

9596
# raise SIGTERM to pserver
96-
os.kill(pid, signal.SIGTERM)
97+
os.kill(p1.pid, signal.SIGKILL)
98+
p1.join()
9799

98100
# run pserver on CPU in async mode
99-
pid = self._start_pserver(False, False)
100-
self._wait_ps_ready(pid)
101+
p2 = self._start_pserver(False, False)
102+
self._wait_ps_ready(p2.pid)
101103

102104
# raise SIGTERM to pserver
103-
os.kill(pid, signal.SIGTERM)
105+
os.kill(p2.pid, signal.SIGKILL)
106+
p2.join()
104107

105108

106109
if __name__ == '__main__':

0 commit comments

Comments
 (0)