Commit cfbc92e

"polish document"
1 parent 46c61b3 commit cfbc92e

2 files changed: +19 -47 lines changed

doc/design/evaluator.md

Lines changed: 19 additions & 10 deletions
@@ -15,35 +15,44 @@ Currently, every operation is expressed in the graph. we divide the evaluator pr
 3. Merge the mini-batch statistics to form the evaluation result for multiple mini-batches. When it comes to distributed training/Multi-GPU training, aggregate the values from different devices.

 ### Implementation
-This design is shown in the Python API. There would be an abstract Python interface and multiple inheritances for each evaluation method.
+This design is shown in the Python API.
+Each metric operator needs to calculate its metric statistics and return batch-aware states; the Python side is responsible for accumulating the states over each pass.

+
 ```python
 class Evaluator(object):
     """
     Evaluator Base class.
     """
-    def __init__(self):
+    def __init__(self, name, **kwargs):
         """
         Different evaluators may have different metric states. E.g., Accuracy needs two variables, total and right sample counts.
         Auc needs four variables, `true_positives`,
-        `true_negatives`, `false_positives` and `false_negatives`. So every evaluator should create its needed variables and append the related mini-batch operators to main_program.
+        `true_negatives`, `false_positives` and `false_negatives`. So every evaluator should create its needed variables and append them to main_program.

         The initialization of Evaluator should be responsible for:
         create metric states and append to the main_program
-        add mini-batch evaluator calculate operators to the main_program
-        add increment operators to accumulate the metric states
         """
         pass

-    def clear(self):
+    def _update_ops(self, input, label, **kwargs):
+        """
+        Add mini-batch evaluator calculate operators to the main_program.
+        Add increment operators to accumulate the metric states.
+        """
+
+
+    def reset(self, executor, program=None):
         """
-        clear metric states at the begin of each pass/user specified batch
+        Reset metric states at the beginning of each pass/user-specified batch number.
+        Execute the reset_program to reset the states.
         """
-        return init_program
+

-    def evaluate(self):
+    def eval(self, executor, program=None):
         """
         Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
+        Execute the eval_program and return the result.
         """
-        return eval_program
+        return eval_result
 ```

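To make the interface described in this hunk concrete, here is a minimal, self-contained sketch of the accumulate/reset/eval cycle the docstrings describe. It is not the framework's actual Accuracy evaluator: the graph-side operators are stood in for by plain Python counts, and the names `AccuracySketch` and `update` are illustrative only.

```python
import numpy as np


class AccuracySketch(object):
    """Illustrative evaluator following the interface above.

    In the real design the per-mini-batch statistics are produced by
    operators appended to main_program; here they are passed in as
    plain numbers so the state-accumulation pattern is easy to see.
    """

    def __init__(self, name="accuracy"):
        self.name = name
        self.reset()

    def reset(self):
        # Clear metric states at the beginning of each pass.
        self.correct = 0
        self.total = 0

    def update(self, batch_correct, batch_total):
        # Stand-in for the increment operators that accumulate
        # the metric states across mini-batches.
        self.correct += int(batch_correct)
        self.total += int(batch_total)

    def eval(self):
        # Merge the mini-batch statistics into the pass-level result.
        return float(self.correct) / max(self.total, 1)


# Per-batch (correct, total) counts, as an executor might return them.
batches = [np.array([28, 32]), np.array([30, 32]), np.array([27, 32])]
metric = AccuracySketch()
for correct, total in batches:
    metric.update(correct, total)
print(metric.eval())  # ~0.885
```
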
python/paddle/v2/framework/evaluator.py

Lines changed: 0 additions & 37 deletions
@@ -181,43 +181,6 @@ def eval(self, executor, program=None):
         return np.array(out[0])


-# Demo for composing low level ops to compute the F1 metric
-class FScore(Evaluator):
-    def __init__(self, input, label, beta=1.0, **kwargs):
-        super(F1, self).__init__("FScore", **kwargs)
-        block = self._program.global_block()
-        g_tp = block.create_var(
-            name=unique_name("Tp"), persistable=True, dtype="int64", shape=[1])
-        g_fn = block.create_var(
-            name=unique_name("Fn"), persistable=True, dtype="int64", shape=[1])
-        g_fp = block.create_var(
-            name=unique_name("Fp"), persistable=True, dtype="int64", shape=[1])
-
-        self._states["Tp"] = g_tp
-        self._states["Fp"] = g_fp
-        self._states["Fn"] = g_fn
-
-    def _update_ops(self):
-        block = self._program.global_block()
-        equal_out = block.create_var()
-        block.append_op(
-            type="equal",
-            inputs={"X": [input],
-                    "Y": [label]},
-            outputs={"Out": equal_out})
-
-        positive = block.create_var()
-        block.append_op(
-            type="sequence_pool",
-            inputs={"X": [equal_out]},
-            outputs={"Out": positive},
-            attrs={"pooltype": "SUM"})
-        batch = block.create_var(
-            name=feed_var_name,
-            type=core.VarDesc.VarType.FEED_MINIBATCH,
-            persistable=True)
-
-
 # FIXME(dzh): add a decorator to call _update_ops automatically
 def accuracy(*args, **kwargs):
     cls = Accuracy(*args, **kwargs)

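For context on what the removed demo was aiming at, here is a small self-contained sketch of the F-score arithmetic over the three states the demo tracked (`Tp`, `Fp`, `Fn`). The graph-operator plumbing is omitted and the helper name is illustrative, not part of this commit.

```python
def f_score_from_states(tp, fp, fn, beta=1.0):
    """Compute F_beta from accumulated true-positive, false-positive,
    and false-negative counts:

        F_beta = (1 + beta^2) * precision * recall / (beta^2 * precision + recall)
    """
    precision = float(tp) / max(tp + fp, 1)
    recall = float(tp) / max(tp + fn, 1)
    denom = beta * beta * precision + recall
    if denom == 0:
        return 0.0
    return (1 + beta * beta) * precision * recall / denom


# Example: states accumulated over one pass.
print(f_score_from_states(tp=80, fp=20, fn=10))  # F1 ~ 0.842
```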