
Commit 87568cf

Merge pull request #8643 from JiayiFeng/remove_evaluator
Removes Accuracy
2 parents f2cf2a7 + ea508c9

9 files changed: +162 −99 lines changed

benchmark/cluster/vgg16/vgg16_fluid.py

Lines changed: 20 additions & 15 deletions

@@ -1,11 +1,11 @@
 # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-#
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-#
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
-#
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -138,13 +138,14 @@ def main():
     avg_cost = fluid.layers.mean(x=cost)

     # Evaluator
-    accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
+    batch_size = fluid.layers.create_tensor(dtype='int64')
+    batch_acc = fluid.layers.accuracy(
+        input=predict, label=label, total=batch_size)

     # inference program
     inference_program = fluid.default_main_program().clone()
     with fluid.program_guard(inference_program):
-        test_target = accuracy.metrics + accuracy.states
-        inference_program = fluid.io.get_inference_program(test_target)
+        inference_program = fluid.io.get_inference_program(batch_acc)

     # Optimization
     optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
@@ -157,27 +158,30 @@ def main():

     # test
     def test(exe):
-        accuracy.reset(exe)
+        test_pass_acc = fluid.average.WeightedAverage()
         for batch_id, data in enumerate(test_reader()):
             img_data = np.array(map(lambda x: x[0].reshape(data_shape),
                                     data)).astype("float32")
             y_data = np.array(map(lambda x: x[1], data)).astype("int64")
             y_data = y_data.reshape([-1, 1])

-            exe.run(inference_program,
-                    feed={"pixel": img_data,
-                          "label": y_data})
+            outs = exe.run(inference_program,
+                           feed={"pixel": img_data,
+                                 "label": y_data},
+                           fetch_list=[batch_acc, batch_size])
+            test_pass_acc.add(value=np.array(outs[0]), weight=np.array(outs[1]))

-        return accuracy.eval(exe)
+        return test_pass_acc.eval()

     def train_loop(exe, trainer_prog):
         iters = 0
         ts = time.time()
+        train_pass_acc = fluid.average.WeightedAverage()
         for pass_id in range(args.num_passes):
             # train
             start_time = time.time()
             num_samples = 0
-            accuracy.reset(exe)
+            train_pass_acc.reset()
             with profiler.profiler("CPU", 'total') as prof:
                 for batch_id, data in enumerate(train_reader()):
                     ts = time.time()
@@ -187,21 +191,22 @@ def train_loop(exe, trainer_prog):
                     y_data = np.array(map(lambda x: x[1], data)).astype("int64")
                     y_data = y_data.reshape([-1, 1])

-                    loss, acc = exe.run(
+                    loss, acc, b_size = exe.run(
                         trainer_prog,
                         feed={"pixel": img_data,
                               "label": y_data},
-                        fetch_list=[avg_cost] + accuracy.metrics)
+                        fetch_list=[avg_cost, batch_acc, batch_size])
                     iters += 1
                     num_samples += len(data)
+                    train_pass_acc.add(value=acc, weight=b_size)
                     print(
                         "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, Speed = %.2f img/s"
                         % (pass_id, iters, loss, acc,
                            len(data) / (time.time() - ts))
                     )  # The accuracy is the accumulation of batches, but not the current batch.

             pass_elapsed = time.time() - start_time
-            pass_train_acc = accuracy.eval(exe)
+            pass_train_acc = train_pass_acc.eval()
             pass_test_acc = test(exe)
             print(
                 "Pass = %d, Training performance = %f imgs/s, Train accuracy = %f, Test accuracy = %f\n"

python/paddle/fluid/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@
 import optimizer
 import backward
 import regularizer
+import average
 from param_attr import ParamAttr, WeightNormParamAttr
 from data_feeder import DataFeeder
 from core import LoDTensor, CPUPlace, CUDAPlace

python/paddle/fluid/average.py

Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+"""
+    Class of all kinds of Average.
+
+    All Averages are accomplished via Python totally.
+    They do not change Paddle's Program, nor do anything to
+    modify NN model's configuration. They are completely
+    wrappers of Python functions.
+"""
+
+
+def _is_number_(var):
+    return isinstance(var, int) or isinstance(var, float) or (isinstance(
+        var, np.ndarray) and var.shape == (1, ))
+
+
+def _is_number_or_matrix_(var):
+    return _is_number_(var) or isinstance(var, np.ndarray)
+
+
+class WeightedAverage(object):
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.numerator = None
+        self.denominator = None
+
+    def add(self, value, weight):
+        if not _is_number_or_matrix_(value):
+            raise ValueError(
+                "The 'value' must be a number(int, float) or a numpy ndarray.")
+        if not _is_number_(weight):
+            raise ValueError("The 'weight' must be a number(int, float).")
+
+        if self.numerator is None or self.denominator is None:
+            self.numerator = value * weight
+            self.denominator = weight
+        else:
+            self.numerator += value * weight
+            self.denominator += weight
+
+    def eval(self):
+        if self.numerator is None or self.denominator is None:
+            raise ValueError(
+                "There is no data to be averaged in WeightedAverage.")
+        return self.numerator / self.denominator
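
Note: the new fluid.average.WeightedAverage keeps a running numerator and denominator purely on the Python side, so a pass-level metric is just the batch-size-weighted mean of per-batch values. A minimal usage sketch, assuming a Fluid build that already contains this commit (the accuracy values and batch sizes below are made-up numbers; in the callers above they come from the batch_acc and batch_size fetch targets):

    import numpy as np
    import paddle.fluid as fluid

    pass_acc = fluid.average.WeightedAverage()
    pass_acc.reset()
    # Pretend three mini-batches produced these (accuracy, batch size) pairs.
    for acc, n in [(np.array([0.50]), 128), (np.array([0.75]), 128), (np.array([0.80]), 64)]:
        pass_acc.add(value=acc, weight=n)
    print(pass_acc.eval())  # [0.66] -- the batch-size-weighted mean over the pass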

python/paddle/fluid/evaluator.py

Lines changed: 0 additions & 38 deletions

@@ -108,44 +108,6 @@ def create_state(self, suffix, dtype, shape):
         return state


-class Accuracy(Evaluator):
-    """
-    Average Accuracy for multiple mini-batches.
-    """
-
-    def __init__(self, input, label, k=1, **kwargs):
-        super(Accuracy, self).__init__("accuracy", **kwargs)
-        main_program = self.helper.main_program
-        if main_program.current_block().idx != 0:
-            raise ValueError("You can only invoke Evaluator in root block")
-
-        self.total = self.create_state(dtype='int64', shape=[1], suffix='total')
-        self.correct = self.create_state(
-            dtype='int64', shape=[1], suffix='correct')
-        total = self.helper.create_tmp_variable(dtype='int')
-        correct = self.helper.create_tmp_variable(dtype='int')
-        acc = layers.accuracy(
-            input=input, label=label, k=k, total=total, correct=correct)
-        total = layers.cast(x=total, dtype='int64')
-        correct = layers.cast(x=correct, dtype='int64')
-        layers.sums(input=[self.total, total], out=self.total)
-        layers.sums(input=[self.correct, correct], out=self.correct)
-
-        self.metrics.append(acc)
-
-    def eval(self, executor, eval_program=None):
-        if eval_program is None:
-            eval_program = Program()
-        block = eval_program.current_block()
-        with program_guard(main_program=eval_program):
-            total = _clone_var_(block, self.total)
-            correct = _clone_var_(block, self.correct)
-            total = layers.cast(total, dtype='float32')
-            correct = layers.cast(correct, dtype='float32')
-            out = layers.elementwise_div(x=correct, y=total)
-        return np.array(executor.run(eval_program, fetch_list=[out])[0])
-
-
 class ChunkEvaluator(Evaluator):
     """
     Accumulate counter numbers output by chunk_eval from mini-batches and

python/paddle/fluid/layers/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -28,6 +28,8 @@
 from math_op_patch import *
 import detection
 from detection import *
+import metric
+from metric import *
 from learning_rate_scheduler import *

 __all__ = []
@@ -39,4 +41,5 @@
 __all__ += ops.__all__
 __all__ += device.__all__
 __all__ += detection.__all__
+__all__ += metric.__all__
 __all__ += learning_rate_scheduler.__all__

python/paddle/fluid/layers/metric.py

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+All layers just related to metric.
+"""
+
+from ..layer_helper import LayerHelper
+from ..initializer import Normal, Constant
+from ..framework import Variable
+from ..param_attr import ParamAttr
+
+__all__ = ['accuracy']
+
+
+def accuracy(input, label, k=1, correct=None, total=None):
+    """
+    This function computes the accuracy using the input and label.
+    The output is the top_k inputs and their indices.
+    """
+    helper = LayerHelper("accuracy", **locals())
+    topk_out = helper.create_tmp_variable(dtype=input.dtype)
+    topk_indices = helper.create_tmp_variable(dtype="int64")
+    helper.append_op(
+        type="top_k",
+        inputs={"X": [input]},
+        outputs={"Out": [topk_out],
+                 "Indices": [topk_indices]},
+        attrs={"k": k})
+    acc_out = helper.create_tmp_variable(dtype="float32")
+    if correct is None:
+        correct = helper.create_tmp_variable(dtype="int64")
+    if total is None:
+        total = helper.create_tmp_variable(dtype="int64")
+    helper.append_op(
+        type="accuracy",
+        inputs={
+            "Out": [topk_out],
+            "Indices": [topk_indices],
+            "Label": [label]
+        },
+        outputs={
+            "Accuracy": [acc_out],
+            "Correct": [correct],
+            "Total": [total],
+        })
+    return acc_out
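
Note: with accuracy moved into fluid.layers.metric (re-exported through fluid.layers), callers that want a pass-level number create their own total tensor, pass it in, and fetch it alongside the per-batch accuracy, as the updated benchmark and test in this commit do. A rough sketch of the wiring, assuming a Fluid build with this commit; the image/label/fc toy network is only an illustrative stand-in and not part of the change:

    import paddle.fluid as fluid

    # Toy network, purely to give accuracy() something to score.
    image = fluid.layers.data(name='image', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    predict = fluid.layers.fc(input=image, size=10, act='softmax')

    batch_size = fluid.layers.create_tensor(dtype='int64')  # written by the accuracy op
    batch_acc = fluid.layers.accuracy(input=predict, label=label, total=batch_size)

    # At run time, fetch both and feed them into fluid.average.WeightedAverage:
    #     acc, weight = exe.run(fluid.default_main_program(),
    #                           feed=feeder.feed(data),
    #                           fetch_list=[batch_acc, batch_size])
    #     pass_acc.add(value=acc, weight=weight)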

python/paddle/fluid/layers/nn.py

Lines changed: 1 addition & 36 deletions

@@ -35,7 +35,6 @@
     'cos_sim',
     'cross_entropy',
     'square_error_cost',
-    'accuracy',
     'chunk_eval',
     'sequence_conv',
     'conv2d',
@@ -1022,40 +1021,6 @@ def square_error_cost(input, label):
     return square_out


-def accuracy(input, label, k=1, correct=None, total=None):
-    """
-    This function computes the accuracy using the input and label.
-    The output is the top_k inputs and their indices.
-    """
-    helper = LayerHelper("accuracy", **locals())
-    topk_out = helper.create_tmp_variable(dtype=input.dtype)
-    topk_indices = helper.create_tmp_variable(dtype="int64")
-    helper.append_op(
-        type="top_k",
-        inputs={"X": [input]},
-        outputs={"Out": [topk_out],
-                 "Indices": [topk_indices]},
-        attrs={"k": k})
-    acc_out = helper.create_tmp_variable(dtype="float32")
-    if correct is None:
-        correct = helper.create_tmp_variable(dtype="int64")
-    if total is None:
-        total = helper.create_tmp_variable(dtype="int64")
-    helper.append_op(
-        type="accuracy",
-        inputs={
-            "Out": [topk_out],
-            "Indices": [topk_indices],
-            "Label": [label]
-        },
-        outputs={
-            "Accuracy": [acc_out],
-            "Correct": [correct],
-            "Total": [total],
-        })
-    return acc_out
-
-
 def chunk_eval(input,
                label,
                chunk_scheme,
@@ -3182,7 +3147,7 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None):
         data = fluid.layers.data(name='data', shape=[128], dtype='float32')
         label = fluid.layers.data(name='label', shape=[100], dtype='int64')
         fc = fluid.layers.fc(input=data, size=100)
-        out = fluid.layers.smooth_l1(logits=fc, label=label)
+        out = fluid.layers.smooth_l1(x=fc, y=label)
     """
     helper = LayerHelper('smooth_l1_loss', **locals())
     diff = helper.create_tmp_variable(dtype=x.dtype)

python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py

Lines changed: 11 additions & 6 deletions

@@ -122,7 +122,8 @@ def conv_block(input, num_filter, groups, dropouts):
 optimizer = fluid.optimizer.Adam(learning_rate=0.001)
 opts = optimizer.minimize(avg_cost)

-accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
+batch_size = fluid.layers.create_tensor(dtype='int64')
+batch_acc = fluid.layers.accuracy(input=predict, label=label, total=batch_size)

 fluid.memory_optimize(fluid.default_main_program())

@@ -144,13 +145,17 @@ def conv_block(input, num_filter, groups, dropouts):
 exe.run(fluid.default_startup_program())

 i = 0
+
+accuracy = fluid.average.WeightedAverage()
 for pass_id in range(PASS_NUM):
-    accuracy.reset(exe)
+    accuracy.reset()
     for data in train_reader():
-        loss, acc = exe.run(fluid.default_main_program(),
-                            feed=feeder.feed(data),
-                            fetch_list=[avg_cost] + accuracy.metrics)
-        pass_acc = accuracy.eval(exe)
+        loss, acc, weight = exe.run(
+            fluid.default_main_program(),
+            feed=feeder.feed(data),
+            fetch_list=[avg_cost, batch_acc, batch_size])
+        accuracy.add(value=acc, weight=weight)
+        pass_acc = accuracy.eval()
         print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str(
             pass_acc))
         # this model is slow, so if we can train two mini batch, we think it works properly.
