@@ -82,8 +82,18 @@ def run_trainer(self, place, args):
82
82
strategy = fluid .ExecutionStrategy ()
83
83
strategy .num_threads = 1
84
84
strategy .allow_op_delay = False
85
+ build_stra = fluid .BuildStrategy ()
86
+
87
+ if args .use_reduce :
88
+ build_stra .reduce_strategy = fluid .BuildStrategy .ReduceStrategy .Reduce
89
+ else :
90
+ build_stra .reduce_strategy = fluid .BuildStrategy .ReduceStrategy .AllReduce
91
+
85
92
exe = fluid .ParallelExecutor (
86
- True , loss_name = avg_cost .name , exec_strategy = strategy )
93
+ True ,
94
+ loss_name = avg_cost .name ,
95
+ exec_strategy = strategy ,
96
+ build_strategy = build_stra )
87
97
88
98
feed_var_list = [
89
99
var for var in trainer_prog .global_block ().vars .values ()
@@ -123,6 +133,7 @@ def runtime_main(test_class):
123
133
'--current_endpoint' , type = str , required = False , default = "" )
124
134
parser .add_argument ('--sync_mode' , action = 'store_true' )
125
135
parser .add_argument ('--mem_opt' , action = 'store_true' )
136
+ parser .add_argument ('--use_reduce' , action = 'store_true' )
126
137
127
138
args = parser .parse_args ()
128
139
@@ -149,20 +160,25 @@ def setUp(self):
149
160
self ._python_interp = "python"
150
161
self ._sync_mode = True
151
162
self ._mem_opt = False
163
+ self ._use_reduce = False
152
164
self ._setup_config ()
153
165
154
166
def start_pserver (self , model_file , check_error_log ):
155
-
156
167
ps0_ep , ps1_ep = self ._ps_endpoints .split ("," )
157
- ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist %s %s"
158
- sync_mode_str = "--sync_mode" if self ._sync_mode else ""
159
- mem_opt_str = "--mem_opt" if self ._mem_opt else ""
168
+ ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist"
160
169
ps0_cmd = ps_cmd % \
161
170
(self ._python_interp , model_file , self ._ps_endpoints , ps0_ep ,
162
- self ._trainers , sync_mode_str , mem_opt_str )
171
+ self ._trainers )
163
172
ps1_cmd = ps_cmd % \
164
173
(self ._python_interp , model_file , self ._ps_endpoints , ps1_ep ,
165
- self ._trainers , sync_mode_str , mem_opt_str )
174
+ self ._trainers )
175
+
176
+ if self ._sync_mode :
177
+ ps0_cmd += " --sync_mode"
178
+ ps1_cmd += " --sync_mode"
179
+ if self ._mem_opt :
180
+ ps0_cmd += " --mem_opt"
181
+ ps1_cmd += " --mem_opt"
166
182
167
183
ps0_pipe = subprocess .PIPE
168
184
ps1_pipe = subprocess .PIPE
@@ -242,17 +258,23 @@ def check_with_place(self, model_file, delta=1e-3, check_error_log=False):
242
258
self ._wait_ps_ready (ps1 .pid )
243
259
244
260
ps0_ep , ps1_ep = self ._ps_endpoints .split ("," )
245
- tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist %s %s"
246
- sync_mode_str = "--sync_mode" if self ._sync_mode else ""
247
- mem_opt_str = "--mem_opt" if self ._mem_opt else ""
261
+ tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist"
248
262
tr0_cmd = tr_cmd % \
249
263
(self ._python_interp , model_file , self ._ps_endpoints ,
250
- 0 , ps0_ep ,
251
- self ._trainers , sync_mode_str , mem_opt_str )
264
+ 0 , ps0_ep , self ._trainers )
252
265
tr1_cmd = tr_cmd % \
253
266
(self ._python_interp , model_file , self ._ps_endpoints ,
254
- 1 , ps1_ep ,
255
- self ._trainers , sync_mode_str , mem_opt_str )
267
+ 1 , ps1_ep , self ._trainers )
268
+
269
+ if self ._sync_mode :
270
+ tr0_cmd += " --sync_mode"
271
+ tr1_cmd += " --sync_mode"
272
+ if self ._mem_opt :
273
+ tr0_cmd += " --mem_opt"
274
+ tr1_cmd += " --mem_opt"
275
+ if self ._use_reduce :
276
+ tr0_cmd += " --use_reduce"
277
+ tr1_cmd += " --use_reduce"
256
278
257
279
env0 = {"CUDA_VISIBLE_DEVICES" : "0" }
258
280
env1 = {"CUDA_VISIBLE_DEVICES" : "1" }
@@ -303,6 +325,8 @@ def check_with_place(self, model_file, delta=1e-3, check_error_log=False):
303
325
# FIXME: use terminate() instead of sigkill.
304
326
os .kill (ps0 .pid , signal .SIGKILL )
305
327
os .kill (ps1 .pid , signal .SIGKILL )
328
+ ps0 .terminate ()
329
+ ps1 .terminate ()
306
330
ps0 .wait ()
307
331
ps1 .wait ()
308
332
FNULL .close ()
0 commit comments