Skip to content

Commit a619695

Browse files
authored
Feature/enhance evaluator (#5824)
* Stash * Stash * Polish Evaluator * Merge code * Revert
1 parent 1f6002e commit a619695

File tree

8 files changed

+144
-213
lines changed

8 files changed

+144
-213
lines changed

paddle/operators/math/selected_rows_functor.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,6 @@ template struct SelectedRowsAddToTensor<platform::GPUPlace, float>;
227227
template struct SelectedRowsAddToTensor<platform::GPUPlace, double>;
228228
template struct SelectedRowsAddToTensor<platform::GPUPlace, int>;
229229
template struct SelectedRowsAddToTensor<platform::GPUPlace, int64_t>;
230-
231230
} // namespace math
232231
} // namespace operators
233232
} // namespace paddle

python/paddle/v2/fluid/evaluator.py

Lines changed: 96 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
import numpy as np
2-
from paddle.v2.fluid.framework import Program, g_main_program, unique_name, Variable
3-
import paddle.v2.fluid.core as core
42

3+
import paddle.v2.fluid.layers as layers
4+
from paddle.v2.fluid.framework import Program, unique_name, \
5+
Variable
6+
from paddle.v2.fluid.layer_helper import LayerHelper
57

6-
def _clone_var_in_block_(block, var):
8+
__all__ = ['Accuracy']
9+
10+
11+
def _clone_var_(block, var):
712
assert isinstance(var, Variable)
813
return block.create_var(
914
name=var.name,
@@ -16,175 +21,115 @@ def _clone_var_in_block_(block, var):
1621

1722
class Evaluator(object):
1823
"""
19-
Evalutor Base class.
20-
21-
create metric states
22-
add mini-batch evaluator caculate operator
23-
add increment operator to accumulate the metric states
24+
Base Class for all evaluators
25+
26+
Args:
27+
name(str): The name of evaluator. such as, "accuracy". Used for generate
28+
temporary variable name.
29+
main_program(Program, optional): The evaluator should be added to this
30+
main_program. Default g_main_program
31+
startup_program(Program, optional):The parameter should be added to this
32+
startup_program. Default g_startup_program
33+
34+
Attributes:
35+
states(list): The list of state variables. states will be reset to zero
36+
when `reset` is invoked.
37+
metrics(list): The list of metrics variables. They will be calculate
38+
every mini-batch
2439
"""
2540

2641
def __init__(self, name, **kwargs):
42+
self.states = []
43+
self.metrics = []
44+
self.helper = LayerHelper(name, **kwargs)
45+
46+
def reset(self, executor, reset_program=None):
2747
"""
28-
init the global states
48+
reset metric states at the begin of each pass/user specified batch
2949
"""
30-
self._states = {}
31-
if kwargs.has_key("main_program"):
32-
self._main_program = kwargs.get("main_program")
33-
else:
34-
self._main_program = g_main_program
50+
if reset_program is None:
51+
reset_program = Program()
52+
53+
for var in self.states:
54+
assert isinstance(var, Variable)
55+
g_var = _clone_var_(reset_program.current_block(), var)
56+
layers.fill_constant(
57+
shape=g_var.shape,
58+
value=0.0,
59+
dtype=g_var.dtype,
60+
out=g_var,
61+
main_program=reset_program)
3562

36-
def states(self):
37-
return self._states
63+
executor.run(reset_program)
3864

39-
def _update_ops(self, *args, **kwargs):
65+
def eval(self, executor, eval_program=None):
4066
"""
41-
append update ops to the global states
67+
Evaluate the statistics merged by multiple mini-batches.
4268
"""
4369
raise NotImplementedError()
4470

45-
def reset(self, executor, reset_program=None):
71+
def create_state(self, suffix, dtype, shape):
4672
"""
47-
Clear metric states at the begin of each pass/user specified batch
48-
"""
49-
if reset_program == None:
50-
reset_program = Program()
51-
else:
52-
reset_program = program
53-
block = reset_program.global_block()
54-
for k, var in self._states.iteritems():
55-
g_var = _clone_var_in_block_(block, var)
56-
zeros = block.create_var(dtype="float32", persistable=True)
57-
block.append_op(
58-
type="fill_constant",
59-
outputs={"Out": [zeros]},
60-
attrs={
61-
"shape": g_var.shape,
62-
"value": .0,
63-
"dtype": 5,
64-
})
65-
block.append_op(
66-
type="scale", inputs={"X": zeros}, outputs={"Out": g_var})
67-
executor.run(reset_program, fetch_list=self._states.values())
73+
Create state variable.
74+
75+
NOTE: It is not a public API.
76+
77+
Args:
78+
suffix(str): the state suffix.
79+
dtype(str|core.DataType): the state data type
80+
shape(tuple|list): the shape of state
81+
82+
Returns: State variable
6883
69-
def eval(self, executor, eval_program=None):
70-
"""
71-
Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
7284
"""
73-
raise NotImplementedError()
85+
state = self.helper.create_variable(
86+
name="_".join([unique_name(self.helper.name), suffix]),
87+
persistable=True,
88+
dtype=dtype,
89+
shape=shape)
90+
self.states.append(state)
91+
return state
7492

7593

7694
class Accuracy(Evaluator):
7795
"""
78-
Accuracy need two state variable Total, Correct
96+
Average Accuracy for multiple mini-batches.
7997
"""
8098

81-
def __init__(self, *args, **kwargs):
99+
def __init__(self, input, label, k=1, **kwargs):
82100
super(Accuracy, self).__init__("accuracy", **kwargs)
83-
block = self._main_program.global_block()
84-
g_total = block.create_var(
85-
name=unique_name("Total"),
86-
persistable=True,
87-
dtype="int64",
88-
shape=[1])
89-
g_correct = block.create_var(
90-
name=unique_name("Correct"),
91-
persistable=True,
92-
dtype="int64",
93-
shape=[1])
94-
self._states["Total"] = g_total
95-
self._states["Correct"] = g_correct
96-
97-
def _update_ops(self, input, label, k=1, **kwargs):
98-
block = self._main_program.global_block()
99-
topk_out = block.create_var(dtype=input.dtype)
100-
topk_indices = block.create_var(dtype="int64")
101-
block.append_op(
102-
type="top_k",
103-
inputs={"X": [input]},
104-
outputs={"Out": [topk_out],
105-
"Indices": [topk_indices]},
106-
attrs={"k": k})
107-
acc_out = block.create_var(dtype=kwargs.get("out_dtype", "float32"))
108-
correct = block.create_var(dtype="int64", persistable=True)
109-
total = block.create_var(dtype="int64", persistable=True)
110-
block.append_op(
111-
type="accuracy",
112-
inputs={
113-
"Out": [topk_out],
114-
"Indices": [topk_indices],
115-
"Label": [label]
116-
},
117-
outputs={
118-
"Accuracy": [acc_out],
119-
"Correct": [correct],
120-
"Total": [total],
121-
})
122-
123-
block.append_op(
124-
type="cast",
125-
inputs={"X": [self._states["Total"]]},
126-
outputs={"Out": [self._states["Total"]]},
127-
attrs={
128-
"in_dtype": 5, # float32
129-
"out_dtype": 2, # int32
130-
})
131-
block.append_op(
132-
type="cast",
133-
inputs={"X": [self._states["Correct"]]},
134-
outputs={"Out": [self._states["Correct"]]},
135-
attrs={
136-
"in_dtype": 5,
137-
"out_dtype": 2,
138-
})
139-
140-
block.append_op(
141-
type="elementwise_add",
142-
inputs={"X": [self._states["Total"]],
143-
"Y": [total]},
144-
outputs={"Out": [self._states["Total"]]})
145-
block.append_op(
146-
type="elementwise_add",
147-
inputs={"X": [self._states["Correct"]],
148-
"Y": [correct]},
149-
outputs={"Out": [self._states["Correct"]]})
150-
151-
return acc_out
101+
main_program = self.helper.main_program
102+
if main_program.current_block().idx != 0:
103+
raise ValueError("You can only invoke Evaluator in root block")
104+
105+
self.total = self.create_state(dtype='int64', shape=[1], suffix='total')
106+
self.correct = self.create_state(
107+
dtype='int64', shape=[1], suffix='correct')
108+
kwargs = {'main_program': main_program}
109+
total = self.helper.create_tmp_variable(dtype='int')
110+
correct = self.helper.create_tmp_variable(dtype='int')
111+
acc = layers.accuracy(
112+
input=input,
113+
label=label,
114+
k=k,
115+
total=total,
116+
correct=correct,
117+
**kwargs)
118+
total = layers.cast(x=total, dtype='int64', **kwargs)
119+
correct = layers.cast(x=correct, dtype='int64', **kwargs)
120+
layers.sums(input=[self.total, total], out=self.total, **kwargs)
121+
layers.sums(input=[self.correct, correct], out=self.correct, **kwargs)
122+
123+
self.metrics.append(acc)
152124

153125
def eval(self, executor, eval_program=None):
154-
if eval_program != None:
155-
eval_program = eval_program
156-
else:
126+
if eval_program is None:
157127
eval_program = Program()
158-
block = eval_program.global_block()
159-
eval_out = block.create_var(dtype=self._states["Total"].dtype)
160-
e_total = _clone_var_in_block_(block, self._states["Total"])
161-
e_correct = _clone_var_in_block_(block, self._states["Correct"])
162-
block.append_op(
163-
type="cast",
164-
inputs={"X": [e_total]},
165-
outputs={"Out": [e_total]},
166-
attrs={
167-
"in_dtype": 2, # int32
168-
"out_dtype": 5, # float32
169-
})
170-
block.append_op(
171-
type="cast",
172-
inputs={"X": [e_correct]},
173-
outputs={"Out": [e_correct]},
174-
attrs={
175-
"in_dtype": 2,
176-
"out_dtype": 5,
177-
})
178-
block.append_op(
179-
type="elementwise_div",
180-
inputs={"X": e_correct,
181-
"Y": e_total},
182-
outputs={"Out": eval_out})
183-
out = executor.run(eval_program, fetch_list=[eval_out])
184-
return np.array(out[0])
185-
186-
187-
def accuracy(*args, **kwargs):
188-
cls = Accuracy(*args, **kwargs)
189-
out = cls._update_ops(*args, **kwargs)
190-
return cls, out
128+
block = eval_program.current_block()
129+
kwargs = {'main_program': eval_program}
130+
total = _clone_var_(block, self.total)
131+
correct = _clone_var_(block, self.correct)
132+
total = layers.cast(total, dtype='float32', **kwargs)
133+
correct = layers.cast(correct, dtype='float32', **kwargs)
134+
out = layers.elementwise_div(x=correct, y=total, **kwargs)
135+
return np.array(executor.run(eval_program, fetch_list=[out])[0])

python/paddle/v2/fluid/layers.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,7 @@ def func(**kwargs):
418418
_create_op_func_('mean')
419419
_create_op_func_('mul')
420420
_create_op_func_('elementwise_add')
421+
_create_op_func_('elementwise_div')
421422
_create_op_func_('dropout')
422423
_create_op_func_('reshape')
423424
_create_op_func_('sigmoid')
@@ -457,13 +458,14 @@ def concat(input, axis, main_program=None, startup_program=None):
457458
return out
458459

459460

460-
def sums(input, main_program=None, startup_program=None):
461+
def sums(input, out=None, main_program=None, startup_program=None):
461462
"""
462463
This function takes in the input and performs the sum operation on it
463464
and returns that as the output.
464465
"""
465466
helper = LayerHelper('sum', **locals())
466-
out = helper.create_tmp_variable(dtype=helper.input_dtype())
467+
if out is None:
468+
out = helper.create_tmp_variable(dtype=helper.input_dtype())
467469
helper.append_op(type='sum', inputs={'X': input}, outputs={'Out': out})
468470
return out
469471

@@ -606,7 +608,7 @@ def square_error_cost(input, label, **kwargs):
606608
return square_out
607609

608610

609-
def accuracy(input, label, k=1, **kwargs):
611+
def accuracy(input, label, k=1, correct=None, total=None, **kwargs):
610612
"""
611613
This function computes the accuracy using the input and label.
612614
The output is the top_k inputs and their indices.
@@ -620,10 +622,11 @@ def accuracy(input, label, k=1, **kwargs):
620622
outputs={"Out": [topk_out],
621623
"Indices": [topk_indices]},
622624
attrs={"k": k})
623-
acc_out_dtype = kwargs.get("out_dtype", "float32")
624625
acc_out = helper.create_tmp_variable(dtype="float32")
625-
correct = helper.create_tmp_variable(dtype="int64")
626-
total = helper.create_tmp_variable(dtype="int64")
626+
if correct is None:
627+
correct = helper.create_tmp_variable(dtype="int64")
628+
if total is None:
629+
total = helper.create_tmp_variable(dtype="int64")
627630
helper.append_op(
628631
type="accuracy",
629632
inputs={
@@ -1355,6 +1358,19 @@ def lod_rank_table(x, level=0, main_program=None):
13551358
return table
13561359

13571360

1361+
def topk(input, k, main_program=None, startup_program=None):
1362+
helper = LayerHelper('topk', **locals())
1363+
topk_out = helper.create_tmp_variable(dtype=input.data_type)
1364+
topk_indices = helper.create_tmp_variable(dtype='int64')
1365+
helper.append_op(
1366+
type='top_k',
1367+
inputs={'X': [input]},
1368+
outputs={'Out': [topk_out],
1369+
'Indices': [topk_indices]},
1370+
attrs={'k': k})
1371+
return topk_out, topk_indices
1372+
1373+
13581374
def lod_tensor_to_array(x, table, main_program=None):
13591375
"""
13601376
This function creates an operator to convert an LOD_Tensor to
@@ -1388,14 +1404,20 @@ def array_to_lod_tensor(x, table, main_program=None):
13881404
return tmp
13891405

13901406

1391-
def fill_constant(shape, dtype, value, main_program=None, startup_program=None):
1407+
def fill_constant(shape,
1408+
dtype,
1409+
value,
1410+
out=None,
1411+
main_program=None,
1412+
startup_program=None):
13921413
"""
13931414
This function creates a tensor , with shape as mentioned in the input and
13941415
specified dtype and fills this up with a constant value that
13951416
comes in the input. It also sets the stop_gradient to be True.
13961417
"""
13971418
helper = LayerHelper("fill_constant", **locals())
1398-
out = helper.create_tmp_variable(dtype=dtype)
1419+
if out is None:
1420+
out = helper.create_tmp_variable(dtype=dtype)
13991421
helper.append_op(
14001422
type='fill_constant',
14011423
inputs={},

0 commit comments

Comments
 (0)