 import subprocess
 import six
 import argparse
+import pickle
+import numpy as np
 
 import paddle.fluid as fluid
 
@@ -128,10 +130,15 @@ def get_data():
             else:
                 return origin_batch
 
+        out_losses = []
         for _ in six.moves.xrange(RUN_STEP):
             loss, = exe.run(fetch_list=[avg_cost.name],
                             feed=feeder.feed(get_data()))
-            print(loss)
+            out_losses.append(loss[0])
+        if six.PY2:
+            print(pickle.dumps(out_losses))
+        else:
+            sys.stdout.buffer.write(pickle.dumps(out_losses))
 
 
 def runtime_main(test_class):
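The hunk above switches the trainer from printing one loss per line to collecting all per-step losses and emitting a single pickled payload on stdout. A minimal sketch of that trainer-side protocol, written as a standalone demo rather than the real harness (the loss values here are placeholders):

```python
# Trainer-side sketch of the pickle-over-stdout protocol (standalone demo,
# not the PaddlePaddle harness; the loss values are made up).
import pickle
import sys

import six

out_losses = [0.125, 0.118, 0.112]  # stand-in for per-step training losses

if six.PY2:
    # Python 2: stdout accepts byte strings, so print() is safe.
    print(pickle.dumps(out_losses))
else:
    # Python 3: print() would str()-ify the bytes and corrupt the payload;
    # write the raw pickle bytes to the underlying binary buffer instead.
    sys.stdout.buffer.write(pickle.dumps(out_losses))
```

Pickling preserves the exact float values, which is what lets the parent-side hunks further down drop cpt.to_text() and the fragile newline splitting.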
@@ -149,7 +156,7 @@ def runtime_main(test_class):
     parser.add_argument('--use_cuda', action='store_true')
     parser.add_argument('--use_reduce', action='store_true')
     parser.add_argument(
-        '--use_reader_alloc', action='store_true', required=False, default=True)
+        '--use_reader_alloc', action='store_true', required=False)
     parser.add_argument('--batch_size', required=False, type=int, default=2)
     parser.add_argument(
         '--batch_merge_repeat', required=False, type=int, default=1)
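Dropping default=True from --use_reader_alloc fixes an argparse footgun: action='store_true' already defaults the flag to False, so forcing default=True made the option True whether or not it was passed, and it could never be disabled. A quick illustration:

```python
# Why default=True defeats a store_true flag: argparse defaults such flags
# to False, so removing the override makes the option opt-in again.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use_reader_alloc', action='store_true', required=False)

print(parser.parse_args([]).use_reader_alloc)                      # False
print(parser.parse_args(['--use_reader_alloc']).use_reader_alloc)  # True
```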
@@ -237,21 +244,6 @@ def start_pserver(self, model_file, check_error_log, required_envs):
 
         return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe
 
-    def _wait_ps_ready(self, pid):
-        retry_times = 50
-        while True:
-            assert retry_times >= 0, "wait ps ready failed"
-            time.sleep(3)
-            try:
-                # the listen_and_serv_op would touch a file which contains the listen port
-                # on the /tmp directory until it was ready to process all the RPC call.
-                os.stat("/tmp/paddle.%d.port" % pid)
-                return
-            except os.error as e:
-                sys.stderr.write('waiting for pserver: %s, left retry %d\n' %
-                                 (e, retry_times))
-                retry_times -= 1
-
     def _run_local(self,
                    model,
                    envs,
@@ -288,23 +280,20 @@ def _run_local(self,
                                       env=envs)
 
         local_out, local_err = local_proc.communicate()
-        local_ret = cpt.to_text(local_out)
 
         if check_error_log:
             err_log.close()
 
-        sys.stderr.write('local_stdout: %s\n' % local_ret)
+        sys.stderr.write('local_stdout: %s\n' % pickle.loads(local_out))
         sys.stderr.write('local_stderr: %s\n' % local_err)
 
-        local_losses = local_ret.split("\n")
-        return local_losses
+        return pickle.loads(local_out)
 
     def _run_cluster(self, model, envs, check_error_log):
         # Run dist train to compare with local results
         ps0, ps1, ps0_pipe, ps1_pipe = self.start_pserver(model,
                                                           check_error_log, envs)
-        self._wait_ps_ready(ps0.pid)
-        self._wait_ps_ready(ps1.pid)
+
 
         ps0_ep, ps1_ep = self._ps_endpoints.split(",")
 
         tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist"
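On the receiving side, Popen.communicate() hands back the child's stdout as bytes, and pickle.loads() turns those bytes straight back into the list of floats, so the cpt.to_text() decode and newline splitting are no longer needed. A parent-side sketch matching the trainer demo above (trainer.py is a hypothetical stand-in for the real trainer script):

```python
# Parent-side sketch: capture the child's stdout as raw bytes and unpickle.
# 'trainer.py' is a hypothetical script that writes pickle.dumps(losses).
import pickle
import subprocess
import sys

proc = subprocess.Popen(
    [sys.executable, 'trainer.py'],
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE)
out, err = proc.communicate()

losses = pickle.loads(out)  # exact float list, no text parsing
print(losses)
```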
@@ -339,8 +328,8 @@ def _run_cluster(self, model, envs, check_error_log):
         env0.update(envs)
         env1.update(envs)
 
-        print("tr0_cmd:{}, env0: {}".format(tr0_cmd, env0))
-        print("tr1_cmd:{}, env1: {}".format(tr1_cmd, env1))
+        print("tr0_cmd:{}".format(tr0_cmd))
+        print("tr1_cmd:{}".format(tr1_cmd))
         tr0_pipe = open("/tmp/tr0_err.log", "wb")
         tr1_pipe = open("/tmp/tr1_err.log", "wb")
 
@@ -356,9 +345,7 @@ def _run_cluster(self, model, envs, check_error_log):
                                     env=env1)
 
         tr0_out, tr0_err = tr0_proc.communicate()
-        tr0_loss_text = cpt.to_text(tr0_out)
         tr1_out, tr1_err = tr1_proc.communicate()
-        tr1_loss_text = cpt.to_text(tr1_out)
 
         # close trainer file
         tr0_pipe.close()
@@ -373,15 +360,13 @@ def _run_cluster(self, model, envs, check_error_log):
         ps1.terminate()
 
         # print log
-        sys.stderr.write('trainer 0 stdout:\n %s\n' % tr0_loss_text)
-        sys.stderr.write('trainer 0 stderr:\n %s\n' % tr0_err)
-        sys.stderr.write('trainer 1 stdout: %s\n' % tr1_loss_text)
+        sys.stderr.write('trainer 0 stdout: %s\n' % pickle.loads(tr0_out))
+        sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err)
+        sys.stderr.write('trainer 1 stdout: %s\n' % pickle.loads(tr1_out))
         sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
 
-        tr0_losses = tr0_loss_text.split("\n")
-        tr1_losses = tr1_loss_text.split("\n")
-
-        return tr0_losses, tr1_losses
+        # return tr0_losses, tr1_losses
+        return pickle.loads(tr0_out), pickle.loads(tr1_out)
 
     def check_with_place(self,
                          model_file,
@@ -411,9 +396,9 @@ def check_with_place(self,
                                                      check_error_log)
 
         for step_id in range(RUN_STEP):
-            local_loss = eval(local_losses[step_id])[0]
-            tr0_loss = eval(tr0_losses[step_id])[0]
-            tr1_loss = eval(tr1_losses[step_id])[0]
-            dist_loss = (tr0_loss + tr1_loss) / 2
-            print(str(local_loss) + ":" + str(dist_loss))
-            self.assertAlmostEqual(local_loss, dist_loss, delta=delta)
+            local_loss = local_losses[step_id]
+            tr0_loss = tr0_losses[step_id]
+            tr1_loss = tr1_losses[step_id]
+            dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss])) / 2
+            print("=======", local_loss, ":", dist_loss[0], "=======")
+            self.assertAlmostEqual(local_loss, dist_loss[0], delta=delta)
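With pickled lists on both sides, the comparison loop no longer needs eval() on printed text: each entry is already a numeric loss, the two trainers' losses are averaged, and the average is checked against the local run within delta. A sketch of the arithmetic with placeholder values:

```python
# Sketch of the loss comparison, using made-up float32 values of the kind
# exe.run() returns; the tolerance stands in for the test's delta.
import numpy as np

local_loss = np.float32(0.125)
tr0_loss = np.float32(0.124)
tr1_loss = np.float32(0.127)

# Average the two trainers' losses, then compare against the local run.
dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss])) / 2
assert abs(local_loss - dist_loss[0]) <= 1e-2  # mirrors assertAlmostEqual
```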