Skip to content

Commit 60229c1

Browse files
committed
Follow comments.
test=develop
1 parent 2939fc9 commit 60229c1

File tree

2 files changed

+56
-9
lines changed

2 files changed

+56
-9
lines changed

python/paddle/fluid/metrics.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# limitations under the License.
1414
"""
1515
Fluid Metrics
16-
17-
The metrics are accomplished via Python natively.
1816
"""
1917

2018
from __future__ import print_function
@@ -24,6 +22,12 @@
2422
import warnings
2523
import six
2624

25+
from .layer_helper import LayerHelper
26+
from .initializer import Constant
27+
from . import unique_name
28+
from .framework import Program, Variable, program_guard
29+
from . import layers
30+
2731
__all__ = [
2832
'MetricBase',
2933
'CompositeMetric',
@@ -598,7 +602,7 @@ class DetectionMAP(object):
598602
Examples:
599603
.. code-block:: python
600604
601-
exe = fluid.executor(place)
605+
exe = fluid.Executor(place)
602606
map_evaluator = fluid.Evaluator.DetectionMAP(input,
603607
gt_label, gt_box, gt_difficult)
604608
cur_map, accum_map = map_evaluator.get_map_var()
@@ -624,9 +628,6 @@ def __init__(self,
624628
overlap_threshold=0.5,
625629
evaluate_difficult=True,
626630
ap_version='integral'):
627-
from . import layers
628-
from .layer_helper import LayerHelper
629-
from .initializer import Constant
630631

631632
self.helper = LayerHelper('map_eval')
632633
gt_label = layers.cast(x=gt_label, dtype=gt_box.dtype)
@@ -692,7 +693,6 @@ def _create_state(self, suffix, dtype, shape):
692693
shape(tuple|list): the shape of state
693694
Returns: State variable
694695
"""
695-
from . import unique_name
696696
state = self.helper.create_variable(
697697
name="_".join([unique_name.generate(self.helper.name), suffix]),
698698
persistable=True,
@@ -717,8 +717,6 @@ def reset(self, executor, reset_program=None):
717717
reset_program(Program|None): a single Program for reset process.
718718
If None, will create a Program.
719719
"""
720-
from .framework import Program, Variable, program_guard
721-
from . import layers
722720

723721
def _clone_var_(block, var):
724722
assert isinstance(var, Variable)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import paddle.fluid as fluid
18+
from paddle.fluid.framework import Program, program_guard
19+
20+
21+
class TestMetricsDetectionMap(unittest.TestCase):
22+
def test_detection_map(self):
23+
program = fluid.Program()
24+
with program_guard(program):
25+
detect_res = fluid.layers.data(
26+
name='detect_res',
27+
shape=[10, 6],
28+
append_batch_size=False,
29+
dtype='float32')
30+
label = fluid.layers.data(
31+
name='label',
32+
shape=[10, 1],
33+
append_batch_size=False,
34+
dtype='float32')
35+
box = fluid.layers.data(
36+
name='bbox',
37+
shape=[10, 4],
38+
append_batch_size=False,
39+
dtype='float32')
40+
map_eval = fluid.metrics.DetectionMAP(
41+
detect_res, label, box, class_num=21)
42+
cur_map, accm_map = map_eval.get_map_var()
43+
self.assertIsNotNone(cur_map)
44+
self.assertIsNotNone(accm_map)
45+
print(str(program))
46+
47+
48+
if __name__ == '__main__':
49+
unittest.main()

0 commit comments

Comments
 (0)