@@ -40,10 +40,7 @@ def parse_args():
    parser.add_argument(
        '--batch_size', type=int, default=32, help='The minibatch size.')
    parser.add_argument(
-        '--learning_rate',
-        type=float,
-        default=0.001,
-        help='The minibatch size.')
+        '--learning_rate', type=float, default=0.001, help='The learning rate.')
    # TODO(wuyi): add "--use_fake_data" option back.
    parser.add_argument(
        '--skip_batch_num',
@@ -231,10 +228,7 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
            train_losses.append(loss)
            print("Pass: %d, Iter: %d, Loss: %f\n" %
                  (pass_id, iters, np.mean(train_losses)))
-        train_elapsed = time.time() - start_time
-        examples_per_sec = num_samples / train_elapsed
-        print('\nTotal examples: %d, total time: %.5f, %.5f examples/sec\n' %
-              (num_samples, train_elapsed, examples_per_sec))
+        print_train_time(start_time, time.time(), num_samples)
        print("Pass: %d, Loss: %f" % (pass_id, np.mean(train_losses)))
        # evaluation
        if not args.no_test and batch_acc != None:
@@ -315,10 +309,7 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
            if batch_id % 1 == 0:
                print("Pass %d, batch %d, loss %s" %
                      (pass_id, batch_id, np.array(loss)))
-        train_elapsed = time.time() - start_time
-        examples_per_sec = num_samples / train_elapsed
-        print('\nTotal examples: %d, total time: %.5f, %.5f examples/sed\n' %
-              (num_samples, train_elapsed, examples_per_sec))
+        print_train_time(start_time, time.time(), num_samples)
        if not args.no_test and batch_acc != None:
            test_acc = test(startup_exe, infer_prog, test_reader, feeder,
                            batch_acc)
@@ -329,20 +320,27 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
def print_arguments(args):
    vars(args)['use_nvprof'] = (vars(args)['use_nvprof'] and
                                vars(args)['device'] == 'GPU')
-    print('----------- resnet Configuration Arguments -----------')
+    print('----------- Configuration Arguments -----------')
    for arg, value in sorted(vars(args).iteritems()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')


+def print_train_time(start_time, end_time, num_samples):
+    train_elapsed = end_time - start_time
+    examples_per_sec = num_samples / train_elapsed
+    print('\nTotal examples: %d, total time: %.5f, %.5f examples/sec\n' %
+          (num_samples, train_elapsed, examples_per_sec))
+
+
def main():
    args = parse_args()
    print_arguments(args)

    # the unique trainer id, starting from 0, needed by trainer
    # only
    nccl_id_var, num_trainers, trainer_id = (
-        None, 1, int(os.getenv("PADDLE_TRAINER_ID", "-1")))
+        None, 1, int(os.getenv("PADDLE_TRAINER_ID", "0")))

    if args.use_cprof:
        pr = cProfile.Profile()
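For reference, the print_train_time helper consolidated by this diff can be exercised on its own. The snippet below is a minimal standalone sketch, not part of the commit: the timing loop, the time.sleep call, and the per-iteration sample count are illustrative stand-ins for the real training iterations (32 simply mirrors the default --batch_size above).

import time

def print_train_time(start_time, end_time, num_samples):
    # Report total samples processed, elapsed wall time, and throughput.
    train_elapsed = end_time - start_time
    examples_per_sec = num_samples / train_elapsed
    print('\nTotal examples: %d, total time: %.5f, %.5f examples/sec\n' %
          (num_samples, train_elapsed, examples_per_sec))

if __name__ == '__main__':
    start = time.time()
    num_samples = 0
    for _ in range(10):      # stand-in for the per-batch training loop
        time.sleep(0.01)     # placeholder for one minibatch of work
        num_samples += 32    # illustrative; matches the default --batch_size
    print_train_time(start, time.time(), num_samples)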