Skip to content

Commit 9e1799c

Browse files
committed
"fix based on comments"
1 parent cfbc92e commit 9e1799c

File tree

3 files changed

+5
-11
lines changed

3 files changed

+5
-11
lines changed

doc/design/evaluator.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@ class Evaluator(object):
4242
"""
4343

4444

45-
def reset(self, executor, program=None):
45+
def reset(self, executor, reset_program=None):
4646
"""
4747
Reset metric states at the begin of each pass/user specified batch number.
4848
Execute the reset_program to reset the states.
4949
"""
5050

5151

52-
def eval(self, executor, program=None):
52+
def eval(self, executor, eval_program=None):
5353
"""
5454
Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
5555
Execute the eval_program and return the result.

python/paddle/v2/framework/evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _update_ops(self, *args, **kwargs):
3939
"""
4040
raise NotImplementedError()
4141

42-
def reset(self, executor, program=None):
42+
def reset(self, executor, reset_program=None):
4343
"""
4444
Clear metric states at the begin of each pass/user specified batch
4545
"""
@@ -63,7 +63,7 @@ def reset(self, executor, program=None):
6363
type="scale", inputs={"X": zeros}, outputs={"Out": g_var})
6464
executor.run(reset_program, fetch_list=self._states.values())
6565

66-
def eval(self, executor, program=None):
66+
def eval(self, executor, eval_program=None):
6767
"""
6868
Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
6969
"""

python/paddle/v2/framework/tests/test_fit_a_line.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from paddle.v2.framework.framework import Program, g_main_program
77
from paddle.v2.framework.io import save_persistables, load_persistables
88
from paddle.v2.framework.executor import Executor
9-
from paddle.v2.framework.evaluator import Accuracy
109

1110
import numpy as np
1211

@@ -32,8 +31,6 @@
3231
main_program=main_program,
3332
startup_program=startup_program)
3433

35-
accuracy = evaluator.Accuracy(input=y_predict, label=y)
36-
3734
cost = layers.square_error_cost(
3835
input=y_predict,
3936
label=y,
@@ -61,7 +58,6 @@
6158
for pass_id in range(PASS_NUM):
6259
save_persistables(exe, "./fit_a_line.model/", main_program=main_program)
6360
load_persistables(exe, "./fit_a_line.model/", main_program=main_program)
64-
accuracy.reset(exe)
6561
for data in train_reader():
6662
x_data = np.array(map(lambda x: x[0], data)).astype("float32")
6763
y_data = np.array(map(lambda x: x[1], data)).astype("float32")
@@ -76,10 +72,8 @@
7672
outs = exe.run(main_program,
7773
feed={'x': tensor_x,
7874
'y': tensor_y},
79-
fetch_list=[avg_cost, accuracy])
75+
fetch_list=[avg_cost])
8076
out = np.array(outs[0])
81-
pass_acc = accuracy.eval(exe)
82-
print pass_acc
8377

8478
if out[0] < 10.0:
8579
exit(0) # if avg cost less than 10.0, we think our code is good.

0 commit comments

Comments
 (0)