@@ -55,6 +55,46 @@ def run_pserver(use_cuda, sync_mode, ip, port, trainers, trainer_id):
55
55
exe .run (pserver_prog )
56
56
57
57
58
def run_pserver_with_empty_block(use_cuda, sync_mode, ip, port, trainers,
                                 trainer_id):
    """Serve a pserver program whose optimize block contains no ops.

    Builds a one-feature linear predictor (fc with bias disabled), transpiles
    it for two pserver endpoints with variable slicing turned off, and serves
    the program for the second endpoint — which, with slicing disabled,
    receives no parameters. Asserts that its optimize block is empty before
    running it.
    """
    # Forward graph: single scalar input through a bias-free fc layer.
    x = fluid.layers.data(name='x', shape=[1], dtype='float32')
    prediction = fluid.layers.fc(input=x, size=1, act=None, bias_attr=False)
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')

    # Squared-error loss, averaged over the batch.
    squared_err = fluid.layers.square_error_cost(input=prediction, label=y)
    loss = fluid.layers.mean(squared_err)

    # Plain SGD; its update ops are what end up distributed across pservers.
    optimizer = fluid.optimizer.SGD(learning_rate=0.001)
    optimizer.minimize(loss)

    device = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    executor = fluid.Executor(device)

    # Two endpoints on the same host; this process serves the second one.
    other_endpoint = "{}:{}".format(ip, int(port) + 1)
    my_endpoint = ip + ":" + port
    endpoint_list = ",".join([other_endpoint, my_endpoint])

    transpiler_config = fluid.DistributeTranspilerConfig()
    transpiler_config.slice_var_up = False
    transpiler = fluid.DistributeTranspiler(config=transpiler_config)
    transpiler.transpile(
        trainer_id,
        pservers=endpoint_list,
        trainers=trainers,
        sync_mode=sync_mode)
    server_program = transpiler.get_pserver_program(my_endpoint)

    # This pserver holds no parameters: the program still carries the extra
    # block, but that block has zero ops.
    assert len(server_program.blocks) == 2
    assert len(server_program.blocks[1].ops) == 0

    startup_program = transpiler.get_startup_program(my_endpoint,
                                                     server_program)
    executor.run(startup_program)
    executor.run(server_program)
96
+
97
+
58
98
class TestListenAndServOp (OpTest ):
59
99
def setUp (self ):
60
100
self .ps_timeout = 5
@@ -63,9 +103,9 @@ def setUp(self):
63
103
self .trainers = 1
64
104
self .trainer_id = 0
65
105
66
- def _start_pserver (self , use_cuda , sync_mode ):
106
+ def _start_pserver (self , use_cuda , sync_mode , pserver_func ):
67
107
p = Process (
68
- target = run_pserver ,
108
+ target = pserver_func ,
69
109
args = (use_cuda , sync_mode , self .ip , self .port , self .trainers ,
70
110
self .trainer_id ))
71
111
p .daemon = True
@@ -92,15 +132,32 @@ def test_rpc_interfaces(self):
92
132
93
133
    def test_handle_signal_in_serv_op(self):
        """Start a pserver (sync, then async) and verify it exits on a signal.

        NOTE(review): relies on ``_wait_ps_ready`` and the module-level
        ``run_pserver``, both defined outside this view.
        """
        # run pserver on CPU in sync mode
        p1 = self._start_pserver(False, True, run_pserver)
        self._wait_ps_ready(p1.pid)

        # send SIGINT to the sync pserver (comment previously said SIGTERM,
        # but the code sends SIGINT) and wait for it to exit
        os.kill(p1.pid, signal.SIGINT)
        p1.join()

        # run pserver on CPU in async mode
        p2 = self._start_pserver(False, False, run_pserver)
        self._wait_ps_ready(p2.pid)

        # send SIGTERM to the async pserver and wait for it to exit
        os.kill(p2.pid, signal.SIGTERM)
        p2.join()
149
+
150
+ def test_list_and_serv_run_empty_optimize_block (self ):
151
+ # run pserver on CPU in sync mode
152
+ p1 = self ._start_pserver (False , True , run_pserver_with_empty_block )
96
153
self ._wait_ps_ready (p1 .pid )
97
154
98
155
# raise SIGTERM to pserver
99
156
os .kill (p1 .pid , signal .SIGINT )
100
157
p1 .join ()
101
158
102
159
# run pserver on CPU in async mode
103
- p2 = self ._start_pserver (False , False )
160
+ p2 = self ._start_pserver (False , False , run_pserver_with_empty_block )
104
161
self ._wait_ps_ready (p2 .pid )
105
162
106
163
# raise SIGTERM to pserver
0 commit comments