Skip to content

Commit 757e2c0

Browse files
committed
move layer accuracy to metric.py
1 parent d4dabe3 commit 757e2c0

File tree

3 files changed

+60
-35
lines changed

3 files changed

+60
-35
lines changed

python/paddle/v2/fluid/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from math_op_patch import *
2929
import detection
3030
from detection import *
31+
import metric
32+
from metric import *
3133

3234
__all__ = []
3335
__all__ += math_op_patch.__all__
@@ -38,3 +40,4 @@
3840
__all__ += ops.__all__
3941
__all__ += device.__all__
4042
__all__ += detection.__all__
43+
__all__ += metric.__all__
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
All layers just related to metric.
16+
"""
17+
18+
from ..layer_helper import LayerHelper
19+
from ..initializer import Normal, Constant
20+
from ..framework import Variable
21+
from ..param_attr import ParamAttr
22+
23+
__all__ = ['accuracy']
24+
25+
26+
def accuracy(input, label, k=1, correct=None, total=None):
27+
"""
28+
This function computes the accuracy using the input and label.
29+
The output is the top_k inputs and their indices.
30+
"""
31+
helper = LayerHelper("accuracy", **locals())
32+
topk_out = helper.create_tmp_variable(dtype=input.dtype)
33+
topk_indices = helper.create_tmp_variable(dtype="int64")
34+
helper.append_op(
35+
type="top_k",
36+
inputs={"X": [input]},
37+
outputs={"Out": [topk_out],
38+
"Indices": [topk_indices]},
39+
attrs={"k": k})
40+
acc_out = helper.create_tmp_variable(dtype="float32")
41+
if correct is None:
42+
correct = helper.create_tmp_variable(dtype="int64")
43+
if total is None:
44+
total = helper.create_tmp_variable(dtype="int64")
45+
helper.append_op(
46+
type="accuracy",
47+
inputs={
48+
"Out": [topk_out],
49+
"Indices": [topk_indices],
50+
"Label": [label]
51+
},
52+
outputs={
53+
"Accuracy": [acc_out],
54+
"Correct": [correct],
55+
"Total": [total],
56+
})
57+
return acc_out

python/paddle/v2/fluid/layers/nn.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
'cos_sim',
3535
'cross_entropy',
3636
'square_error_cost',
37-
'accuracy',
3837
'chunk_eval',
3938
'sequence_conv',
4039
'conv2d',
@@ -1020,40 +1019,6 @@ def square_error_cost(input, label):
10201019
return square_out
10211020

10221021

1023-
def accuracy(input, label, k=1, correct=None, total=None):
1024-
"""
1025-
This function computes the accuracy using the input and label.
1026-
The output is the top_k inputs and their indices.
1027-
"""
1028-
helper = LayerHelper("accuracy", **locals())
1029-
topk_out = helper.create_tmp_variable(dtype=input.dtype)
1030-
topk_indices = helper.create_tmp_variable(dtype="int64")
1031-
helper.append_op(
1032-
type="top_k",
1033-
inputs={"X": [input]},
1034-
outputs={"Out": [topk_out],
1035-
"Indices": [topk_indices]},
1036-
attrs={"k": k})
1037-
acc_out = helper.create_tmp_variable(dtype="float32")
1038-
if correct is None:
1039-
correct = helper.create_tmp_variable(dtype="int64")
1040-
if total is None:
1041-
total = helper.create_tmp_variable(dtype="int64")
1042-
helper.append_op(
1043-
type="accuracy",
1044-
inputs={
1045-
"Out": [topk_out],
1046-
"Indices": [topk_indices],
1047-
"Label": [label]
1048-
},
1049-
outputs={
1050-
"Accuracy": [acc_out],
1051-
"Correct": [correct],
1052-
"Total": [total],
1053-
})
1054-
return acc_out
1055-
1056-
10571022
def chunk_eval(input,
10581023
label,
10591024
chunk_scheme,

0 commit comments

Comments
 (0)