Skip to content

Commit 7679edf

Browse files
authored
Merge pull request #11374 from guochaorong/fix_fluid_benchmark
fix bugs in fluid_benchmark
2 parents 431491a + 90a74a9 commit 7679edf

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

benchmark/fluid/fluid_benchmark.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
180180
print_train_time(start_time, time.time(), num_samples)
181181
print("Pass: %d, Loss: %f" % (pass_id, np.mean(train_losses))),
182182
# evaluation
183-
if not args.no_test and batch_acc:
183+
if not args.no_test and batch_acc and not args.use_reader_op:
184184
pass_test_acc = test(exe, infer_prog, test_reader, feeder,
185185
batch_acc)
186186
print(", Test Accuracy: %f" % pass_test_acc)
@@ -277,11 +277,12 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
277277
batch_id += 1
278278

279279
print_train_time(start_time, time.time(), num_samples)
280-
if not args.no_test and batch_acc:
280+
if not args.no_test and batch_acc and not args.use_reader_op:
281+
# we have not implement record io for test
282+
# skip test when use args.use_reader_op
281283
test_acc = test(startup_exe, infer_prog, test_reader, feeder,
282284
batch_acc)
283285
print("Pass: %d, Test Accuracy: %f\n" % (pass_id, test_acc))
284-
exit(0)
285286

286287

287288
def print_arguments(args):

benchmark/fluid/models/resnet.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,10 @@ def get_model(args):
199199
batched_train_reader = paddle.batch(
200200
paddle.reader.shuffle(
201201
train_reader, buf_size=5120),
202-
batch_size=args.batch_size * args.gpus)
203-
batched_test_reader = paddle.batch(train_reader, batch_size=args.batch_size)
202+
batch_size=args.batch_size * args.gpus,
203+
drop_last=True)
204+
batched_test_reader = paddle.batch(
205+
train_reader, batch_size=args.batch_size, drop_last=True)
204206

205-
return avg_cost, inference_program, optimizer, batched_train_reader, batched_test_reader, batch_acc
207+
return avg_cost, inference_program, optimizer, batched_train_reader,\
208+
batched_test_reader, batch_acc

0 commit comments

Comments
 (0)