Skip to content

Commit bbd7580

Browse files
kexinzhaodaming-lu
authored andcommitted
simplify recognize digits example code (#10722)
1 parent 2a63652 commit bbd7580

File tree

2 files changed

+8
-20
lines changed

2 files changed

+8
-20
lines changed

python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -71,24 +71,18 @@ def event_handler(event):
7171
if isinstance(event, fluid.EndEpochEvent):
7272
test_reader = paddle.batch(
7373
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
74-
test_metrics = trainer.test(
74+
avg_cost, acc = trainer.test(
7575
reader=test_reader, feed_order=['img', 'label'])
76-
avg_cost_set = test_metrics[0]
77-
acc_set = test_metrics[1]
78-
79-
# get test acc and loss
80-
acc = numpy.array(acc_set).mean()
81-
avg_cost = numpy.array(avg_cost_set).mean()
8276

8377
print("avg_cost: %s" % avg_cost)
8478
print("acc : %s" % acc)
8579

86-
if float(acc) > 0.2: # Smaller value to increase CI speed
80+
if acc > 0.2: # Smaller value to increase CI speed
8781
trainer.save_params(save_dirname)
8882
else:
8983
print('BatchID {0}, Test Loss {1:0.2}, Acc {2:0.2}'.format(
90-
event.epoch + 1, float(avg_cost), float(acc)))
91-
if math.isnan(float(avg_cost)):
84+
event.epoch + 1, avg_cost, acc))
85+
if math.isnan(avg_cost):
9286
sys.exit("got NaN loss, training failed.")
9387
elif isinstance(event, fluid.EndStepEvent):
9488
print("Step {0}, Epoch {1} Metrics {2}".format(

python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,24 +55,18 @@ def event_handler(event):
5555
if isinstance(event, fluid.EndEpochEvent):
5656
test_reader = paddle.batch(
5757
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
58-
test_metrics = trainer.test(
58+
avg_cost, acc = trainer.test(
5959
reader=test_reader, feed_order=['img', 'label'])
60-
avg_cost_set = test_metrics[0]
61-
acc_set = test_metrics[1]
62-
63-
# get test acc and loss
64-
acc = numpy.array(acc_set).mean()
65-
avg_cost = numpy.array(avg_cost_set).mean()
6660

6761
print("avg_cost: %s" % avg_cost)
6862
print("acc : %s" % acc)
6963

70-
if float(acc) > 0.2: # Smaller value to increase CI speed
64+
if acc > 0.2: # Smaller value to increase CI speed
7165
trainer.save_params(save_dirname)
7266
else:
7367
print('BatchID {0}, Test Loss {1:0.2}, Acc {2:0.2}'.format(
74-
event.epoch + 1, float(avg_cost), float(acc)))
75-
if math.isnan(float(avg_cost)):
68+
event.epoch + 1, avg_cost, acc))
69+
if math.isnan(avg_cost):
7670
sys.exit("got NaN loss, training failed.")
7771

7872
train_reader = paddle.batch(

0 commit comments

Comments
 (0)