Skip to content

Commit 8c166b6

Browse files
authored
Merge pull request #14012 from qingqing01/map_api
Refine detection mAP in metrics.py.
2 parents 5ed3e6f + af0fab9 commit 8c166b6

File tree

3 files changed

+232
-64
lines changed

3 files changed

+232
-64
lines changed

python/paddle/fluid/evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ class DetectionMAP(Evaluator):
316316
gt_label (Variable): The ground truth label index, which is a LoDTensor
317317
with shape [N, 1].
318318
gt_box (Variable): The ground truth bounding box (bbox), which is a
319-
LoDTensor with shape [N, 6]. The layout is [xmin, ymin, xmax, ymax].
319+
LoDTensor with shape [N, 4]. The layout is [xmin, ymin, xmax, ymax].
320320
gt_difficult (Variable|None): Whether this ground truth is a difficult
321321
bounding bbox, which can be a LoDTensor [N, 1] or not set. If None,
322322
it means all the ground truth labels are not difficult bbox.

python/paddle/fluid/metrics.py

Lines changed: 182 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# limitations under the License.
1414
"""
1515
Fluid Metrics
16-
17-
The metrics are accomplished via Python natively.
1816
"""
1917

2018
from __future__ import print_function
@@ -24,6 +22,12 @@
2422
import warnings
2523
import six
2624

25+
from .layer_helper import LayerHelper
26+
from .initializer import Constant
27+
from . import unique_name
28+
from .framework import Program, Variable, program_guard
29+
from . import layers
30+
2731
__all__ = [
2832
'MetricBase',
2933
'CompositeMetric',
@@ -478,67 +482,6 @@ def eval(self):
478482
return avg_distance, avg_instance_error
479483

480484

481-
class DetectionMAP(MetricBase):
482-
"""
483-
Calculate the detection mean average precision (mAP).
484-
mAP is the metric to measure the accuracy of object detectors
485-
like Faster R-CNN, SSD, etc.
486-
It is the average of the maximum precisions at different recall values.
487-
Please get more information from the following articles:
488-
https://sanchom.wordpress.com/tag/average-precision/
489-
490-
https://arxiv.org/abs/1512.02325
491-
492-
The general steps are as follows:
493-
494-
1. calculate the true positive and false positive according to the input
495-
of detection and labels.
496-
2. calculate mAP value, support two versions: '11 point' and 'integral'.
497-
498-
Examples:
499-
.. code-block:: python
500-
501-
pred = fluid.layers.fc(input=data, size=1000, act="tanh")
502-
batch_map = layers.detection_map(
503-
input,
504-
label,
505-
class_num,
506-
background_label,
507-
overlap_threshold=overlap_threshold,
508-
evaluate_difficult=evaluate_difficult,
509-
ap_version=ap_version)
510-
metric = fluid.metrics.DetectionMAP()
511-
for data in train_reader():
512-
loss, preds, labels = exe.run(fetch_list=[cost, batch_map])
513-
batch_size = data[0]
514-
metric.update(value=batch_map, weight=batch_size)
515-
numpy_map = metric.eval()
516-
"""
517-
518-
def __init__(self, name=None):
519-
super(DetectionMAP, self).__init__(name)
520-
# the current map value
521-
self.value = .0
522-
self.weight = .0
523-
524-
def update(self, value, weight):
525-
if not _is_number_or_matrix_(value):
526-
raise ValueError(
527-
"The 'value' must be a number(int, float) or a numpy ndarray.")
528-
if not _is_number_(weight):
529-
raise ValueError("The 'weight' must be a number(int, float).")
530-
self.value += value
531-
self.weight += weight
532-
533-
def eval(self):
534-
if self.weight == 0:
535-
raise ValueError(
536-
"There is no data in DetectionMAP Metrics. "
537-
"Please check layers.detection_map output has added to DetectionMAP."
538-
)
539-
return self.value / self.weight
540-
541-
542485
class Auc(MetricBase):
543486
"""
544487
Auc metric adapts to the binary classification.
@@ -616,3 +559,179 @@ def eval(self):
616559
idx -= 1
617560

618561
return auc / tot_pos / tot_neg if tot_pos > 0.0 and tot_neg > 0.0 else 0.0
562+
563+
564+
class DetectionMAP(object):
565+
"""
566+
Calculate the detection mean average precision (mAP).
567+
568+
The general steps are as follows:
569+
1. calculate the true positive and false positive according to the input
570+
of detection and labels.
571+
2. calculate mAP value, support two versions: '11 point' and 'integral'.
572+
573+
Please get more information from the following articles:
574+
https://sanchom.wordpress.com/tag/average-precision/
575+
https://arxiv.org/abs/1512.02325
576+
577+
Args:
578+
input (Variable): The detection results, which is a LoDTensor with shape
579+
[M, 6]. The layout is [label, confidence, xmin, ymin, xmax, ymax].
580+
gt_label (Variable): The ground truth label index, which is a LoDTensor
581+
with shape [N, 1].
582+
gt_box (Variable): The ground truth bounding box (bbox), which is a
583+
LoDTensor with shape [N, 4]. The layout is [xmin, ymin, xmax, ymax].
584+
gt_difficult (Variable|None): Whether this ground truth is a difficult
585+
bounding bbox, which can be a LoDTensor [N, 1] or not set. If None,
586+
it means all the ground truth labels are not difficult bbox.
587+
class_num (int): The class number.
588+
background_label (int): The index of background label, the background
589+
label will be ignored. If set to -1, then all categories will be
590+
considered, 0 by defalut.
591+
overlap_threshold (float): The threshold for deciding true/false
592+
positive, 0.5 by defalut.
593+
evaluate_difficult (bool): Whether to consider difficult ground truth
594+
for evaluation, True by defalut. This argument does not work when
595+
gt_difficult is None.
596+
ap_version (string): The average precision calculation ways, it must be
597+
'integral' or '11point'. Please check
598+
https://sanchom.wordpress.com/tag/average-precision/ for details.
599+
- 11point: the 11-point interpolated average precision.
600+
- integral: the natural integral of the precision-recall curve.
601+
602+
Examples:
603+
.. code-block:: python
604+
605+
exe = fluid.Executor(place)
606+
map_evaluator = fluid.Evaluator.DetectionMAP(input,
607+
gt_label, gt_box, gt_difficult)
608+
cur_map, accum_map = map_evaluator.get_map_var()
609+
fetch = [cost, cur_map, accum_map]
610+
for epoch in PASS_NUM:
611+
map_evaluator.reset(exe)
612+
for data in batches:
613+
loss, cur_map_v, accum_map_v = exe.run(fetch_list=fetch)
614+
615+
In the above example:
616+
617+
'cur_map_v' is the mAP of current mini-batch.
618+
'accum_map_v' is the accumulative mAP of one pass.
619+
"""
620+
621+
def __init__(self,
622+
input,
623+
gt_label,
624+
gt_box,
625+
gt_difficult=None,
626+
class_num=None,
627+
background_label=0,
628+
overlap_threshold=0.5,
629+
evaluate_difficult=True,
630+
ap_version='integral'):
631+
632+
self.helper = LayerHelper('map_eval')
633+
gt_label = layers.cast(x=gt_label, dtype=gt_box.dtype)
634+
if gt_difficult:
635+
gt_difficult = layers.cast(x=gt_difficult, dtype=gt_box.dtype)
636+
label = layers.concat([gt_label, gt_difficult, gt_box], axis=1)
637+
else:
638+
label = layers.concat([gt_label, gt_box], axis=1)
639+
640+
# calculate mean average precision (mAP) of current mini-batch
641+
map = layers.detection_map(
642+
input,
643+
label,
644+
class_num,
645+
background_label,
646+
overlap_threshold=overlap_threshold,
647+
evaluate_difficult=evaluate_difficult,
648+
ap_version=ap_version)
649+
650+
states = []
651+
states.append(
652+
self._create_state(
653+
dtype='int32', shape=None, suffix='accum_pos_count'))
654+
states.append(
655+
self._create_state(
656+
dtype='float32', shape=None, suffix='accum_true_pos'))
657+
states.append(
658+
self._create_state(
659+
dtype='float32', shape=None, suffix='accum_false_pos'))
660+
var = self._create_state(dtype='int32', shape=[1], suffix='has_state')
661+
self.helper.set_variable_initializer(
662+
var, initializer=Constant(value=int(0)))
663+
self.has_state = var
664+
665+
# calculate accumulative mAP
666+
accum_map = layers.detection_map(
667+
input,
668+
label,
669+
class_num,
670+
background_label,
671+
overlap_threshold=overlap_threshold,
672+
evaluate_difficult=evaluate_difficult,
673+
has_state=self.has_state,
674+
input_states=states,
675+
out_states=states,
676+
ap_version=ap_version)
677+
678+
layers.fill_constant(
679+
shape=self.has_state.shape,
680+
value=1,
681+
dtype=self.has_state.dtype,
682+
out=self.has_state)
683+
684+
self.cur_map = map
685+
self.accum_map = accum_map
686+
687+
def _create_state(self, suffix, dtype, shape):
688+
"""
689+
Create state variable.
690+
Args:
691+
suffix(str): the state suffix.
692+
dtype(str|core.VarDesc.VarType): the state data type
693+
shape(tuple|list): the shape of state
694+
Returns: State variable
695+
"""
696+
state = self.helper.create_variable(
697+
name="_".join([unique_name.generate(self.helper.name), suffix]),
698+
persistable=True,
699+
dtype=dtype,
700+
shape=shape)
701+
return state
702+
703+
def get_map_var(self):
704+
"""
705+
Returns: mAP variable of current mini-batch and
706+
accumulative mAP variable cross mini-batches.
707+
"""
708+
return self.cur_map, self.accum_map
709+
710+
def reset(self, executor, reset_program=None):
711+
"""
712+
Reset metric states at the begin of each pass/user specified batch.
713+
714+
Args:
715+
executor(Executor): a executor for executing
716+
the reset_program.
717+
reset_program(Program|None): a single Program for reset process.
718+
If None, will create a Program.
719+
"""
720+
721+
def _clone_var_(block, var):
722+
assert isinstance(var, Variable)
723+
return block.create_var(
724+
name=var.name,
725+
shape=var.shape,
726+
dtype=var.dtype,
727+
type=var.type,
728+
lod_level=var.lod_level,
729+
persistable=var.persistable)
730+
731+
if reset_program is None:
732+
reset_program = Program()
733+
with program_guard(main_program=reset_program):
734+
var = _clone_var_(reset_program.current_block(), self.has_state)
735+
layers.fill_constant(
736+
shape=var.shape, value=0, dtype=var.dtype, out=var)
737+
executor.run(reset_program)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import paddle.fluid as fluid
18+
from paddle.fluid.framework import Program, program_guard
19+
20+
21+
class TestMetricsDetectionMap(unittest.TestCase):
22+
def test_detection_map(self):
23+
program = fluid.Program()
24+
with program_guard(program):
25+
detect_res = fluid.layers.data(
26+
name='detect_res',
27+
shape=[10, 6],
28+
append_batch_size=False,
29+
dtype='float32')
30+
label = fluid.layers.data(
31+
name='label',
32+
shape=[10, 1],
33+
append_batch_size=False,
34+
dtype='float32')
35+
box = fluid.layers.data(
36+
name='bbox',
37+
shape=[10, 4],
38+
append_batch_size=False,
39+
dtype='float32')
40+
map_eval = fluid.metrics.DetectionMAP(
41+
detect_res, label, box, class_num=21)
42+
cur_map, accm_map = map_eval.get_map_var()
43+
self.assertIsNotNone(cur_map)
44+
self.assertIsNotNone(accm_map)
45+
print(str(program))
46+
47+
48+
if __name__ == '__main__':
49+
unittest.main()

0 commit comments

Comments
 (0)