@@ -266,7 +266,10 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
             # FIXME(wuyi): For use_reader_op, if the current
             # pass is not the last, the last batch of this pass
             # is also equal to args.batch_size.
-            num_samples += len(args.batch_size)
+            if args.use_reader_op:
+                num_samples += args.batch_size
+            else:
+                num_samples += len(data)
             train_losses.append(loss)
             print("Pass: %d, Iter: %d, Loss: %f\n" %
                   (pass_id, iters, np.mean(train_losses)))
@@ -350,9 +353,6 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
             if iters == args.skip_batch_num:
                 start_time = time.time()
                 num_samples = 0
-            # NOTE: if use reader ops, the input data is not splited to multiple cards
-            if args.use_reader_op and iters >= args.iterations / args.gpus:
-                break
             if args.use_fake_data or args.use_reader_op:
                 try:
                     loss, = exe.run([avg_loss.name])
@@ -362,7 +362,10 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
                 loss, = exe.run([avg_loss.name], feed=feeder.feed(data))
             if args.update_method == "pserver":
                 exe.bcast_params()
-            num_samples += len(data)
+            if args.use_reader_op:
+                num_samples += args.batch_size
+            else:
+                num_samples += len(data)
             iters += 1
             if batch_id % 1 == 0:
                 print("Pass %d, batch %d, loss %s" %