
Commit 6fdd5de

Author: yi.wu (committed)
Commit message: update
1 parent 8893cf1 commit 6fdd5de

File tree

1 file changed: +8 -5 lines


benchmark/fluid/fluid_benchmark.py

Lines changed: 8 additions & 5 deletions
@@ -266,7 +266,10 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
             # FIXME(wuyi): For use_reader_op, if the current
             # pass is not the last, the last batch of this pass
             # is also equal to args.batch_size.
-            num_samples += len(args.batch_size)
+            if args.use_reader_op:
+                num_samples += args.batch_size
+            else:
+                num_samples += len(data)
             train_losses.append(loss)
             print("Pass: %d, Iter: %d, Loss: %f\n" %
                   (pass_id, iters, np.mean(train_losses)))
@@ -350,9 +353,6 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
             if iters == args.skip_batch_num:
                 start_time = time.time()
                 num_samples = 0
-            # NOTE: if use reader ops, the input data is not splited to multiple cards
-            if args.use_reader_op and iters >= args.iterations / args.gpus:
-                break
             if args.use_fake_data or args.use_reader_op:
                 try:
                     loss, = exe.run([avg_loss.name])
@@ -362,7 +362,10 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
                 loss, = exe.run([avg_loss.name], feed=feeder.feed(data))
             if args.update_method == "pserver":
                 exe.bcast_params()
-            num_samples += len(data)
+            if args.use_reader_op:
+                num_samples += args.batch_size
+            else:
+                num_samples += len(data)
             iters += 1
             if batch_id % 1 == 0:
                 print("Pass %d, batch %d, loss %s" %
