
Commit d52fa26

Feature/metrics (#9791)

* "add metrics"
* "add fluid metrics"
* "add import guards"
* "show warnings"
* "add demo"
* "fix ci"
* "add some details"
* "fix cci"
* "add demo Python"
* "add metrics"

Parent: 90084a2

6 files changed: +427 lines, -6 lines


benchmark/fluid/mnist.py

Lines changed: 2 additions & 5 deletions
@@ -139,9 +139,6 @@ def run_benchmark(model, args):

     # inference program
     inference_program = fluid.default_main_program().clone()
-    with fluid.program_guard(inference_program):
-        inference_program = fluid.io.get_inference_program(
-            target_vars=[batch_acc, batch_size_tensor])

     # Optimization
     opt = fluid.optimizer.AdamOptimizer(
@@ -161,7 +158,7 @@ def run_benchmark(model, args):
     train_reader = paddle.batch(
         paddle.dataset.mnist.train(), batch_size=args.batch_size)

-    accuracy = fluid.average.WeightedAverage()
+    accuracy = fluid.metrics.Accuracy()
     iters, num_samples, start_time = 0, 0, time.time()
     for pass_id in range(args.pass_num):
         accuracy.reset()
@@ -184,7 +181,7 @@ def run_benchmark(model, args):
                       "label": y_data},
                 fetch_list=[avg_cost, batch_acc, batch_size_tensor]
             )  # The accuracy is the accumulation of batches, but not the current batch.
-            accuracy.add(value=outs[1], weight=outs[2])
+            accuracy.update(value=outs[1], weight=outs[2])
             iters += 1
             num_samples += len(y_data)
             loss = np.array(outs[0])
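
The change above replaces fluid.average.WeightedAverage with the new fluid.metrics.Accuracy object. Below is a minimal sketch of the reset/update/eval cycle the benchmark now relies on, using synthetic batch values instead of executor fetches; the eval() call and the value/weight argument names are assumptions based on the diff and on the new python/paddle/fluid/metrics.py, which is not shown in this excerpt.

import numpy as np
import paddle.fluid as fluid

accuracy = fluid.metrics.Accuracy()

for pass_id in range(2):
    accuracy.reset()                        # start each pass from a clean state
    for _ in range(5):                      # stand-in for minibatches from exe.run
        batch_acc = float(np.random.rand())     # per-batch accuracy (outs[1])
        batch_size = 32                         # per-batch sample count (outs[2])
        accuracy.update(value=batch_acc, weight=batch_size)
    # Assumed: eval() returns the weight-averaged accuracy over the pass.
    print("pass %d accuracy: %f" % (pass_id, accuracy.eval()))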

python/paddle/fluid/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@
 import backward
 import regularizer
 import average
+import metrics
 from param_attr import ParamAttr, WeightNormParamAttr
 from data_feeder import DataFeeder
 from core import LoDTensor, CPUPlace, CUDAPlace, CUDAPinnedPlace

python/paddle/fluid/average.py

Lines changed: 6 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.

 import numpy as np
+import warnings
 """
     Class of all kinds of Average.

@@ -22,6 +23,8 @@
     wrappers of Python functions.
 """

+__all__ = ["WeightedAverage"]
+

 def _is_number_(var):
     return isinstance(var, int) or isinstance(var, float) or (isinstance(
@@ -34,6 +37,9 @@ def _is_number_or_matrix_(var):

 class WeightedAverage(object):
     def __init__(self):
+        warnings.warn(
+            "%s is deprecated; please use fluid.metrics.Accuracy instead." %
+            (self.__class__.__name__), Warning)
         self.reset()

     def reset(self):

python/paddle/fluid/evaluator.py

Lines changed: 4 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import warnings
 import numpy as np

 import layers
@@ -59,6 +60,9 @@ class Evaluator(object):
     """

     def __init__(self, name, **kwargs):
+        warnings.warn(
+            "%s is deprecated because keeping a modified program inside the evaluator easily causes bugs; please use fluid.metrics.%s instead."
+            % (self.__class__.__name__, self.__class__.__name__), Warning)
         self.states = []
         self.metrics = []
         self.helper = LayerHelper(name, **kwargs)
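
Both average.py and evaluator.py now emit their deprecation notices through plain warnings.warn(..., Warning) calls, so the standard library's warning filters decide how visible they are. A small sketch of the two usual choices during a migration; the message pattern is an assumption matching the strings added in this commit.

import warnings

# Surface the deprecation notice on every occurrence while auditing old code.
warnings.simplefilter("always", Warning)

# Or silence just these messages once the move to fluid.metrics is planned.
warnings.filterwarnings(
    "ignore", message=".*is deprecated.*", category=Warning)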

python/paddle/fluid/layers/metric.py

Lines changed: 36 additions & 1 deletion
@@ -15,12 +15,13 @@
 All layers just related to metric.
 """

+import warnings
 from ..layer_helper import LayerHelper
 from ..initializer import Normal, Constant
 from ..framework import Variable
 from ..param_attr import ParamAttr

-__all__ = ['accuracy']
+__all__ = ['accuracy', 'auc']


 def accuracy(input, label, k=1, correct=None, total=None):
@@ -55,3 +56,37 @@ def accuracy(input, label, k=1, correct=None, total=None):
             "Total": [total],
         })
     return acc_out
+
+
+def auc(input, label, curve='ROC', num_thresholds=200):
+    warnings.warn(
+        "This interface is not recommended: fluid.layers.auc computes an AUC "
+        "for every minibatch, but those values cannot be aggregated into a "
+        "pass-level AUC, because AUC is not a weighted average of minibatch "
+        "AUCs. Please use fluid.metrics.Auc instead, which computes the AUC "
+        "in Python and can report both minibatch and pass-level values.",
+        Warning)
+    helper = LayerHelper("auc", **locals())
+    # Rank the predictions; the auc op consumes the top-1 scores and indices.
+    topk_out = helper.create_tmp_variable(dtype=input.dtype)
+    topk_indices = helper.create_tmp_variable(dtype="int64")
+    helper.append_op(
+        type="top_k",
+        inputs={"X": [input]},
+        outputs={"Out": [topk_out],
+                 "Indices": [topk_indices]},
+        attrs={"k": 1})
+    # The auc op accumulates counts over num_thresholds score cut-offs and
+    # integrates the requested curve ('ROC' or 'PR') into a single value.
+    auc_out = helper.create_tmp_variable(dtype="float32")
+    helper.append_op(
+        type="auc",
+        inputs={
+            "Out": [topk_out],
+            "Indices": [topk_indices],
+            "Label": [label]
+        },
+        attrs={"curve": curve,
+               "num_thresholds": num_thresholds},
+        outputs={"AUC": [auc_out], })
+    return auc_out
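
The warning in the new auc layer points users to fluid.metrics.Auc for pass-level aggregation. Below is a minimal sketch of that Python-side aggregation, assuming the class follows the same reset/update/eval pattern as fluid.metrics.Accuracy; the constructor and update() argument names are assumptions, since the new python/paddle/fluid/metrics.py is not part of this excerpt.

import numpy as np
import paddle.fluid as fluid

auc_metric = fluid.metrics.Auc(name="auc", curve='ROC', num_thresholds=200)
auc_metric.reset()

for _ in range(10):                         # stand-in for minibatches from exe.run
    preds = np.random.rand(32, 2).astype("float32")            # class scores
    labels = np.random.randint(0, 2, size=(32, 1)).astype("int64")
    auc_metric.update(preds=preds, labels=labels)

# Pass-level AUC accumulated in Python, which the per-minibatch auc op cannot give.
print("pass AUC: %f" % auc_metric.eval())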
