@@ -76,8 +76,18 @@ def run_trainer(self, place, endpoints, trainer_id, trainers, is_dist=True):
76
76
strategy = fluid .ExecutionStrategy ()
77
77
strategy .num_threads = 1
78
78
strategy .allow_op_delay = False
79
+ build_stra = fluid .BuildStrategy ()
80
+
81
+ if args .use_reduce :
82
+ build_stra .reduce_strategy = fluid .BuildStrategy .ReduceStrategy .Reduce
83
+ else :
84
+ build_stra .reduce_strategy = fluid .BuildStrategy .ReduceStrategy .AllReduce
85
+
79
86
exe = fluid .ParallelExecutor (
80
- True , loss_name = avg_cost .name , exec_strategy = strategy )
87
+ True ,
88
+ loss_name = avg_cost .name ,
89
+ exec_strategy = strategy ,
90
+ build_strategy = build_stra )
81
91
82
92
feed_var_list = [
83
93
var for var in trainer_prog .global_block ().vars .values ()
@@ -106,16 +116,20 @@ def runtime_main(test_class):
106
116
import paddle .fluid as fluid
107
117
import paddle .fluid .core as core
108
118
109
- if len (sys .argv ) != 7 :
110
- print (
111
- "Usage: python dist_se_resnext.py [pserver/trainer] [endpoints] [trainer_id] [current_endpoint] [trainers] [is_dist]"
112
- )
113
- role = sys .argv [1 ]
114
- endpoints = sys .argv [2 ]
115
- trainer_id = int (sys .argv [3 ])
116
- current_endpoint = sys .argv [4 ]
117
- trainers = int (sys .argv [5 ])
118
- is_dist = True if sys .argv [6 ] == "TRUE" else False
119
+ parser = argparse .ArgumentParser (description = 'Run dist test.' )
120
+ parser .add_argument (
121
+ '--role' , type = str , required = True , choices = ['pserver' , 'trainer' ])
122
+ parser .add_argument ('--endpoints' , type = str , required = False , default = "" )
123
+ parser .add_argument ('--is_dist' , action = 'store_true' )
124
+ parser .add_argument ('--trainer_id' , type = int , required = False , default = 0 )
125
+ parser .add_argument ('--trainers' , type = int , required = False , default = 1 )
126
+ parser .add_argument (
127
+ '--current_endpoint' , type = str , required = False , default = "" )
128
+ parser .add_argument ('--sync_mode' , action = 'store_true' )
129
+ parser .add_argument ('--mem_opt' , action = 'store_true' )
130
+ parser .add_argument ('--use_reduce' , action = 'store_true' )
131
+
132
+ args = parser .parse_args ()
119
133
120
134
model = test_class ()
121
135
if role == "pserver" :
@@ -135,16 +149,28 @@ def setUp(self):
135
149
self ._pservers = 2
136
150
self ._ps_endpoints = "127.0.0.1:9123,127.0.0.1:9124"
137
151
self ._python_interp = "python"
152
+ self ._sync_mode = True
153
+ self ._mem_opt = False
154
+ self ._use_reduce = False
155
+ self ._setup_config ()
138
156
139
157
def start_pserver (self , model_file , check_error_log ):
140
158
ps0_ep , ps1_ep = self ._ps_endpoints .split ("," )
141
- ps0_cmd = "%s %s pserver %s 0 %s %d TRUE" % \
159
+ ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist"
160
+ ps0_cmd = ps_cmd % \
142
161
(self ._python_interp , model_file , self ._ps_endpoints , ps0_ep ,
143
162
self ._trainers )
144
- ps1_cmd = "%s %s pserver %s 0 %s %d TRUE" % \
163
+ ps1_cmd = ps_cmd % \
145
164
(self ._python_interp , model_file , self ._ps_endpoints , ps1_ep ,
146
165
self ._trainers )
147
166
167
+ if self ._sync_mode :
168
+ ps0_cmd += " --sync_mode"
169
+ ps1_cmd += " --sync_mode"
170
+ if self ._mem_opt :
171
+ ps0_cmd += " --mem_opt"
172
+ ps1_cmd += " --mem_opt"
173
+
148
174
ps0_pipe = subprocess .PIPE
149
175
ps1_pipe = subprocess .PIPE
150
176
if check_error_log :
@@ -226,12 +252,23 @@ def check_with_place(self, model_file, delta=1e-3, check_error_log=False):
226
252
self ._wait_ps_ready (ps1 .pid )
227
253
228
254
ps0_ep , ps1_ep = self ._ps_endpoints .split ("," )
229
- tr0_cmd = "%s %s trainer %s 0 %s %d TRUE" % \
230
- (self ._python_interp , model_file , self ._ps_endpoints , ps0_ep ,
231
- self ._trainers )
232
- tr1_cmd = "%s %s trainer %s 1 %s %d TRUE" % \
233
- (self ._python_interp , model_file , self ._ps_endpoints , ps1_ep ,
234
- self ._trainers )
255
+ tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist"
256
+ tr0_cmd = tr_cmd % \
257
+ (self ._python_interp , model_file , self ._ps_endpoints ,
258
+ 0 , ps0_ep , self ._trainers )
259
+ tr1_cmd = tr_cmd % \
260
+ (self ._python_interp , model_file , self ._ps_endpoints ,
261
+ 1 , ps1_ep , self ._trainers )
262
+
263
+ if self ._sync_mode :
264
+ tr0_cmd += " --sync_mode"
265
+ tr1_cmd += " --sync_mode"
266
+ if self ._mem_opt :
267
+ tr0_cmd += " --mem_opt"
268
+ tr1_cmd += " --mem_opt"
269
+ if self ._use_reduce :
270
+ tr0_cmd += " --use_reduce"
271
+ tr1_cmd += " --use_reduce"
235
272
236
273
env0 = {"CUDA_VISIBLE_DEVICES" : "0" }
237
274
env1 = {"CUDA_VISIBLE_DEVICES" : "1" }
@@ -282,6 +319,10 @@ def check_with_place(self, model_file, delta=1e-3, check_error_log=False):
282
319
# FIXME: use terminate() instead of SIGKILL.
283
320
os .kill (ps0 .pid , signal .SIGKILL )
284
321
os .kill (ps1 .pid , signal .SIGKILL )
322
+ ps0 .terminate ()
323
+ ps1 .terminate ()
324
+ ps0 .wait ()
325
+ ps1 .wait ()
285
326
FNULL .close ()
286
327
287
328
self .assertAlmostEqual (local_first_loss , dist_first_loss , delta = delta )
0 commit comments