Skip to content

Commit f122a5d

Browse files
authored
Add accuracy layer (#4958)
* Complete accuray layer * Fix error * Fix error * Add 'accuracy' to __all__ * update * Fix Type error * Fix error * Refine unit tests * Fix an unit test error
1 parent 8b1c50c commit f122a5d

File tree

5 files changed

+42
-12
lines changed

5 files changed

+42
-12
lines changed

paddle/operators/accuracy_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ class AccuracyOp : public framework::OperatorWithKernel {
3232
auto inference_dim = ctx->GetInputDim("Inference");
3333
auto label_dim = ctx->GetInputDim("Label");
3434

35-
PADDLE_ENFORCE_EQ(label_dim.size(), 1, "label must be a vector");
35+
PADDLE_ENFORCE_EQ(label_dim.size(), 2, "label's rank must be 2.");
36+
PADDLE_ENFORCE_EQ(label_dim[1], 1, "label's second dimension must be 1");
3637
PADDLE_ENFORCE_EQ(inference_dim[0], label_dim[0],
3738
"inference size must be the same as label size");
3839

@@ -68,7 +69,8 @@ information, or not. But the output only shares the LoD with input `Inference`.
6869
} // namespace paddle
6970

7071
namespace ops = paddle::operators;
71-
REGISTER_OP_WITHOUT_GRADIENT(accuracy, ops::AccuracyOp, ops::AccuracyOpMaker);
72+
REGISTER_OPERATOR(accuracy, ops::AccuracyOp, ops::AccuracyOpMaker,
73+
paddle::framework::EmptyGradOpMaker);
7274
REGISTER_OP_CPU_KERNEL(
7375
accuracy, ops::AccuracyKernel<paddle::platform::CPUPlace, int>,
7476
ops::AccuracyKernel<paddle::platform::CPUPlace, int64_t>);

paddle/operators/top_k_op.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,11 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
5252
AddOutput("Out", "The output tensor of Topk op");
5353
AddOutput("Indices", "The indices of Topk elements of input");
5454
AddComment(
55-
R"DOC(If the input is a vector (1d tensor), finds the k largest entries in the vector and outputs their values and indices as vectors. Thus values[j] is the j-th largest entry in input, and its index is indices[j].
55+
R"DOC(If the input is a vector (1d tensor),
56+
finds the k largest entries in the vector
57+
and outputs their values and indices as vectors.
58+
Thus values[j] is the j-th largest entry in input,
59+
and its index is indices[j].
5660
5761
For matrices, computes the top k entries in each row. )DOC");
5862
AddAttr<int>("k",
@@ -66,6 +70,7 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
6670
} // namespace paddle
6771

6872
namespace ops = paddle::operators;
69-
REGISTER_OP_WITHOUT_GRADIENT(top_k, ops::TopkOp, ops::TopkOpMaker);
73+
REGISTER_OPERATOR(top_k, ops::TopkOp, ops::TopkOpMaker,
74+
paddle::framework::EmptyGradOpMaker);
7075
REGISTER_OP_CPU_KERNEL(top_k,
7176
ops::TopkKernel<paddle::platform::CPUPlace, float>);

python/paddle/v2/framework/layers.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
__all__ = [
77
'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat',
8-
'StaticRNN', 'cast', 'sequence_conv', 'sequence_pool'
8+
'StaticRNN', 'cast', 'sequence_conv', 'sequence_pool', 'accuracy'
99
]
1010

1111

@@ -229,6 +229,26 @@ def square_error_cost(input, label, **kwargs):
229229
return square_out
230230

231231

232+
def accuracy(input, label, k=1, **kwargs):
233+
helper = LayerHelper("accuracy", **kwargs)
234+
topk_out = helper.create_tmp_variable(dtype=input.data_type)
235+
topk_indices = helper.create_tmp_variable(dtype="int64")
236+
helper.append_op(
237+
type="top_k",
238+
inputs={"X": [input]},
239+
outputs={"Out": [topk_out],
240+
"Indices": [topk_indices]},
241+
attrs={"k": k})
242+
acc_out_dtype = kwargs.get("out_dtype", "float32")
243+
acc_out = helper.create_tmp_variable(dtype=acc_out_dtype)
244+
helper.append_op(
245+
type="accuracy",
246+
inputs={"Inference": [topk_indices],
247+
"Label": [label]},
248+
outputs={"Accuracy": [acc_out]})
249+
return acc_out
250+
251+
232252
def sequence_conv(input,
233253
num_filters,
234254
name=None,

python/paddle/v2/framework/tests/test_accuracy_op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ def setUp(self):
88
self.op_type = "accuracy"
99
n = 8192
1010
infer = np.random.randint(0, 2, (n, 1)).astype("int")
11-
label = np.random.randint(0, 2, (n, )).astype("int")
11+
label = np.random.randint(0, 2, (n, 1)).astype("int")
1212
self.inputs = {'Inference': infer, "Label": label}
1313
num_correct = 0
1414
for rowid in xrange(n):
1515
for ele in infer[rowid]:
16-
if ele == label[rowid]:
16+
if ele == label[rowid][0]:
1717
num_correct += 1
1818
break
1919
self.outputs = {

python/paddle/v2/framework/tests/test_recognize_digits_conv.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,14 @@
5151
cost = layers.cross_entropy(
5252
input=predict, label=label, program=program, init_program=init_program)
5353
avg_cost = layers.mean(x=cost, program=program)
54+
accuracy = layers.accuracy(
55+
input=predict, label=label, program=program, init_program=init_program)
5456

5557
sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001)
5658
opts = sgd_optimizer.minimize(avg_cost)
5759

5860
BATCH_SIZE = 50
59-
PASS_NUM = 1
61+
PASS_NUM = 3
6062
train_reader = paddle.batch(
6163
paddle.reader.shuffle(
6264
paddle.dataset.mnist.train(), buf_size=500),
@@ -83,10 +85,11 @@
8385
outs = exe.run(program,
8486
feed={"pixel": tensor_img,
8587
"label": tensor_y},
86-
fetch_list=[avg_cost])
87-
88+
fetch_list=[avg_cost, accuracy])
8889
loss = np.array(outs[0])
90+
acc = np.array(outs[1])
8991

90-
if loss < 10.0:
91-
exit(0) # if avg cost less than 10.0, we think our code is good.
92+
if loss < 10.0 and acc > 0.9:
93+
# if avg cost less than 10.0 and accuracy is larger than 0.9, we think our code is good.
94+
exit(0)
9295
exit(1)

0 commit comments

Comments
 (0)