23
23
from op_test import OpTest
24
24
25
25
26
- def run_pserver (use_cuda , sync_mode , ip , port , trainer_count , trainer_id ):
26
+ def run_pserver (use_cuda , sync_mode , ip , port , trainers , trainer_id ):
27
27
x = fluid .layers .data (name = 'x' , shape = [1 ], dtype = 'float32' )
28
28
y_predict = fluid .layers .fc (input = x , size = 1 , act = None )
29
29
y = fluid .layers .data (name = 'y' , shape = [1 ], dtype = 'float32' )
@@ -39,15 +39,8 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):
39
39
place = fluid .CUDAPlace (0 ) if use_cuda else fluid .CPUPlace ()
40
40
exe = fluid .Executor (place )
41
41
42
- port = os .getenv ("PADDLE_INIT_PORT" , port )
43
- pserver_ips = os .getenv ("PADDLE_INIT_PSERVERS" , ip ) # ip,ip...
44
- eplist = []
45
- for ip in pserver_ips .split ("," ):
46
- eplist .append (':' .join ([ip , port ]))
47
- pserver_endpoints = "," .join (eplist ) # ip:port,ip:port...
48
- trainers = int (os .getenv ("TRAINERS" , trainer_count ))
49
- current_endpoint = os .getenv ("POD_IP" , ip ) + ":" + port
50
- trainer_id = int (os .getenv ("PADDLE_INIT_TRAINER_ID" , trainer_id ))
42
+ pserver_endpoints = ip + ":" + port
43
+ current_endpoint = ip + ":" + port
51
44
t = fluid .DistributeTranspiler ()
52
45
t .transpile (
53
46
trainer_id ,
@@ -62,47 +55,51 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainer_count, trainer_id):
62
55
63
56
class TestListenAndServOp(OpTest):
    """Tests signal handling of the listen_and_serv (parameter-server) op.

    A pserver is launched in a child process via ``run_pserver`` and, once it
    is ready, is shut down with SIGTERM.  NOTE(review): reconstructed from a
    diff view — confirm against the full file before relying on exact layout.
    """

    def setUp(self):
        """Configure endpoint and trainer settings consumed by _start_pserver."""
        # Number of readiness-poll retries (each retry sleeps ~0.5 s in
        # _wait_ps_ready), not a timeout in seconds.
        self.ps_timeout = 5
        self.ip = "127.0.0.1"
        self.port = "6173"
        self.trainers = 1
        self.trainer_id = 1

    def _start_pserver(self, use_cuda, sync_mode):
        """Launch run_pserver in a child process and return the child's pid.

        Args:
            use_cuda: forwarded to run_pserver (CUDAPlace vs CPUPlace).
            sync_mode: forwarded to run_pserver (sync vs async distribution).
        """
        p = Process(
            target=run_pserver,
            args=(use_cuda, sync_mode, self.ip, self.port, self.trainers,
                  self.trainer_id))
        p.start()
        return p.pid

    def _wait_ps_ready(self, pid):
        """Block until the pserver with the given pid is ready to serve.

        Polls for the port file at roughly 0.5 s intervals; fails the test
        (via assert) after self.ps_timeout unsuccessful attempts.
        """
        retry_times = self.ps_timeout
        while True:
            assert retry_times >= 0, "wait ps ready failed"
            time.sleep(0.5)
            try:
                # The listen_and_serv op touches a file (containing its listen
                # port) under /tmp once it is ready to process RPC calls; its
                # existence is the readiness signal.
                os.stat("/tmp/paddle.%d.port" % pid)
                return
            except os.error:
                # os.error is an alias of OSError; stat fails until the
                # pserver creates the file.
                retry_times -= 1

    def test_rpc_interfaces(self):
        # TODO(Yancey1989): need to make sure the rpc interface correctly.
        pass

    def test_handle_signal_in_serv_op(self):
        """Start a pserver in sync then async mode and terminate each with SIGTERM."""
        # run pserver on CPU in sync mode
        pid = self._start_pserver(False, True)
        self._wait_ps_ready(pid)

        # raise SIGTERM to pserver
        os.kill(pid, signal.SIGTERM)

        # run pserver on CPU in async mode
        pid = self._start_pserver(False, False)
        self._wait_ps_ready(pid)

        # raise SIGTERM to pserver
        os.kill(pid, signal.SIGTERM)
106
103
107
104
108
105
if __name__ == '__main__' :
0 commit comments