Skip to content

Commit 608feea

Browse files
authored
Implement detection mAP evaluator wrapper and unify label format between SSD loss and mAP evaluator (#8736)
* Implement mAP evalutor Python interface. * Fix unit testing and uniy label format between SSD loss and mAP evalutor. * Update doc.
1 parent 95a28d1 commit 608feea

File tree

6 files changed

+158
-33
lines changed

6 files changed

+158
-33
lines changed

paddle/fluid/operators/detection_map_op.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,10 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
4747
PADDLE_ENFORCE_EQ(det_dims[1], 6UL,
4848
"The shape is of Input(DetectRes) [N, 6].");
4949
auto label_dims = ctx->GetInputDim("Label");
50-
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
50+
PADDLE_ENFORCE_EQ(label_dims.size(), 2,
5151
"The rank of Input(Label) must be 2, "
5252
"the shape is [N, 6].");
53-
PADDLE_ENFORCE_EQ(label_dims[1], 6UL,
54-
"The shape is of Input(Label) [N, 6].");
53+
PADDLE_ENFORCE_EQ(label_dims[1], 6, "The shape is of Input(Label) [N, 6].");
5554

5655
if (ctx->HasInput("PosCount")) {
5756
PADDLE_ENFORCE(ctx->HasInput("TruePos"),
@@ -96,6 +95,10 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
9695
"instance, the offsets in first dimension are called LoD, "
9796
"the number of offset is N + 1, if LoD[i + 1] - LoD[i] == 0, "
9897
"means there is no ground-truth data.");
98+
AddInput("HasState",
99+
"(Tensor<int>) A tensor with shape [1], 0 means ignoring input "
100+
"states, which including PosCount, TruePos, FalsePos.")
101+
.AsDispensable();
99102
AddInput("PosCount",
100103
"(Tensor) A tensor with shape [Ncls, 1], store the "
101104
"input positive example count of each class, Ncls is the count of "
@@ -145,7 +148,7 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
145148
"(float) "
146149
"The lower bound jaccard overlap threshold of detection output and "
147150
"ground-truth data.")
148-
.SetDefault(.3f);
151+
.SetDefault(.5f);
149152
AddAttr<bool>("evaluate_difficult",
150153
"(bool, default true) "
151154
"Switch to control whether the difficult data is evaluated.")

paddle/fluid/operators/detection_map_op.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,13 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
8787
std::map<int, std::vector<std::pair<T, int>>> true_pos;
8888
std::map<int, std::vector<std::pair<T, int>>> false_pos;
8989

90-
if (in_pos_count != nullptr) {
90+
auto* has_state = ctx.Input<framework::LoDTensor>("HasState");
91+
int state = 0;
92+
if (has_state) {
93+
state = has_state->data<int>()[0];
94+
}
95+
96+
if (in_pos_count != nullptr && state) {
9197
GetInputPos(*in_pos_count, *in_true_pos, *in_false_pos, label_pos_count,
9298
true_pos, false_pos);
9399
}
@@ -202,6 +208,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
202208

203209
int* pos_count_data = output_pos_count.mutable_data<int>(
204210
framework::make_ddim({max_class_id + 1, 1}), ctx.GetPlace());
211+
205212
T* true_pos_data = output_true_pos.mutable_data<T>(
206213
framework::make_ddim({true_pos_count, 2}), ctx.GetPlace());
207214
T* false_pos_data = output_false_pos.mutable_data<T>(

python/paddle/fluid/evaluator.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
from framework import Program, Variable, program_guard
1919
import unique_name
2020
from layer_helper import LayerHelper
21+
from initializer import Constant
2122

2223
__all__ = [
2324
'Accuracy',
2425
'ChunkEvaluator',
2526
'EditDistance',
27+
'DetectionMAP',
2628
]
2729

2830

@@ -285,3 +287,120 @@ def eval(self, executor, eval_program=None):
285287
result = executor.run(
286288
eval_program, fetch_list=[avg_distance, avg_instance_error])
287289
return np.array(result[0]), np.array(result[1])
290+
291+
292+
class DetectionMAP(Evaluator):
293+
"""
294+
Calculate the detection mean average precision (mAP).
295+
296+
TODO (Dang Qingqing): update the following doc.
297+
The general steps are as follows:
298+
1. calculate the true positive and false positive according to the input
299+
of detection and labels.
300+
2. calculate mAP value, support two versions: '11 point' and 'integral'.
301+
302+
Please get more information from the following articles:
303+
https://sanchom.wordpress.com/tag/average-precision/
304+
https://arxiv.org/abs/1512.02325
305+
306+
Args:
307+
input (Variable): The detection results, which is a LoDTensor with shape
308+
[M, 6]. The layout is [label, confidence, xmin, ymin, xmax, ymax].
309+
gt_label (Variable): The ground truth label index, which is a LoDTensor
310+
with shape [N, 1].
311+
gt_difficult (Variable): Whether this ground truth is a difficult
312+
bounding box (bbox), which is a LoDTensor [N, 1].
313+
gt_box (Variable): The ground truth bounding box (bbox), which is a
314+
LoDTensor with shape [N, 6]. The layout is [xmin, ymin, xmax, ymax].
315+
overlap_threshold (float): The threshold for deciding true/false
316+
positive, 0.5 by defalut.
317+
evaluate_difficult (bool): Whether to consider difficult ground truth
318+
for evaluation, True by defalut.
319+
ap_version (string): The average precision calculation ways, it must be
320+
'integral' or '11point'. Please check
321+
https://sanchom.wordpress.com/tag/average-precision/ for details.
322+
- 11point: the 11-point interpolated average precision.
323+
- integral: the natural integral of the precision-recall curve.
324+
325+
Example:
326+
327+
exe = fluid.executor(place)
328+
map_evaluator = fluid.Evaluator.DetectionMAP(input,
329+
gt_label, gt_difficult, gt_box)
330+
cur_map, accum_map = map_evaluator.get_map_var()
331+
fetch = [cost, cur_map, accum_map]
332+
for epoch in PASS_NUM:
333+
map_evaluator.reset(exe)
334+
for data in batches:
335+
loss, cur_map_v, accum_map_v = exe.run(fetch_list=fetch)
336+
337+
In the above example:
338+
339+
'cur_map_v' is the mAP of current mini-batch.
340+
'accum_map_v' is the accumulative mAP of one pass.
341+
"""
342+
343+
def __init__(self,
344+
input,
345+
gt_label,
346+
gt_box,
347+
gt_difficult,
348+
overlap_threshold=0.5,
349+
evaluate_difficult=True,
350+
ap_version='integral'):
351+
super(DetectionMAP, self).__init__("map_eval")
352+
353+
gt_label = layers.cast(x=gt_label, dtype=gt_box.dtype)
354+
gt_difficult = layers.cast(x=gt_difficult, dtype=gt_box.dtype)
355+
label = layers.concat([gt_label, gt_difficult, gt_box], axis=1)
356+
357+
# calculate mean average precision (mAP) of current mini-batch
358+
map = layers.detection_map(
359+
input,
360+
label,
361+
overlap_threshold=overlap_threshold,
362+
evaluate_difficult=evaluate_difficult,
363+
ap_version=ap_version)
364+
365+
self.create_state(dtype='int32', shape=None, suffix='accum_pos_count')
366+
self.create_state(dtype='float32', shape=None, suffix='accum_true_pos')
367+
self.create_state(dtype='float32', shape=None, suffix='accum_false_pos')
368+
369+
self.has_state = None
370+
var = self.helper.create_variable(
371+
persistable=True, dtype='int32', shape=[1])
372+
self.helper.set_variable_initializer(
373+
var, initializer=Constant(value=int(0)))
374+
self.has_state = var
375+
376+
# calculate accumulative mAP
377+
accum_map = layers.detection_map(
378+
input,
379+
label,
380+
overlap_threshold=overlap_threshold,
381+
evaluate_difficult=evaluate_difficult,
382+
has_state=self.has_state,
383+
input_states=self.states,
384+
out_states=self.states,
385+
ap_version=ap_version)
386+
387+
layers.fill_constant(
388+
shape=self.has_state.shape,
389+
value=1,
390+
dtype=self.has_state.dtype,
391+
out=self.has_state)
392+
393+
self.cur_map = map
394+
self.accum_map = accum_map
395+
396+
def get_map_var(self):
397+
return self.cur_map, self.accum_map
398+
399+
def reset(self, executor, reset_program=None):
400+
if reset_program is None:
401+
reset_program = Program()
402+
with program_guard(main_program=reset_program):
403+
var = _clone_var_(reset_program.current_block(), self.has_state)
404+
layers.fill_constant(
405+
shape=var.shape, value=0, dtype=var.dtype, out=var)
406+
executor.run(reset_program)

python/paddle/fluid/layers/detection.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -151,23 +151,34 @@ class number, M is number of bounding boxes. For each category
151151
@autodoc()
152152
def detection_map(detect_res,
153153
label,
154-
pos_count=None,
155-
true_pos=None,
156-
false_pos=None,
157154
overlap_threshold=0.3,
158155
evaluate_difficult=True,
159-
ap_type='integral'):
156+
has_state=None,
157+
input_states=None,
158+
out_states=None,
159+
ap_version='integral'):
160160
helper = LayerHelper("detection_map", **locals())
161161

162-
map_out = helper.create_tmp_variable(dtype='float32')
163-
accum_pos_count_out = helper.create_tmp_variable(dtype='int32')
164-
accum_true_pos_out = helper.create_tmp_variable(dtype='float32')
165-
accum_false_pos_out = helper.create_tmp_variable(dtype='float32')
162+
def __create_var(type):
163+
return helper.create_tmp_variable(dtype=type)
164+
165+
map_out = __create_var('float32')
166+
accum_pos_count_out = out_states[0] if out_states else __create_var('int32')
167+
accum_true_pos_out = out_states[1] if out_states else __create_var(
168+
'float32')
169+
accum_false_pos_out = out_states[2] if out_states else __create_var(
170+
'float32')
171+
172+
pos_count = input_states[0] if input_states else None
173+
true_pos = input_states[1] if input_states else None
174+
false_pos = input_states[2] if input_states else None
175+
166176
helper.append_op(
167177
type="detection_map",
168178
inputs={
169179
'Label': label,
170180
'DetectRes': detect_res,
181+
'HasState': has_state,
171182
'PosCount': pos_count,
172183
'TruePos': true_pos,
173184
'FalsePos': false_pos
@@ -181,9 +192,9 @@ def detection_map(detect_res,
181192
attrs={
182193
'overlap_threshold': overlap_threshold,
183194
'evaluate_difficult': evaluate_difficult,
184-
'ap_type': ap_type
195+
'ap_type': ap_version
185196
})
186-
return map_out, accum_pos_count_out, accum_true_pos_out, accum_false_pos_out
197+
return map_out
187198

188199

189200
def bipartite_match(dist_matrix,

python/paddle/fluid/tests/test_detection.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -158,26 +158,9 @@ def test_detection_map(self):
158158
append_batch_size=False,
159159
dtype='float32')
160160

161-
map_out, accum_pos_count_out, accum_true_pos_out, accum_false_pos_out = layers.detection_map(
162-
detect_res=detect_res, label=label)
161+
map_out = layers.detection_map(detect_res=detect_res, label=label)
163162
self.assertIsNotNone(map_out)
164-
self.assertIsNotNone(accum_pos_count_out)
165-
self.assertIsNotNone(accum_true_pos_out)
166-
self.assertIsNotNone(accum_false_pos_out)
167163
self.assertEqual(map_out.shape, (1, ))
168-
map_out, accum_pos_count_out2, accum_true_pos_out2, accum_false_pos_out2 = layers.detection_map(
169-
detect_res=detect_res, label=label)
170-
self.assertIsNotNone(map_out)
171-
self.assertIsNotNone(accum_pos_count_out2)
172-
self.assertIsNotNone(accum_true_pos_out2)
173-
self.assertIsNotNone(accum_false_pos_out2)
174-
self.assertEqual(map_out.shape, (1, ))
175-
self.assertEqual(accum_pos_count_out.shape,
176-
accum_pos_count_out2.shape)
177-
self.assertEqual(accum_true_pos_out.shape,
178-
accum_true_pos_out2.shape)
179-
self.assertEqual(accum_false_pos_out.shape,
180-
accum_false_pos_out2.shape)
181164
print(str(program))
182165

183166

python/paddle/fluid/tests/unittests/test_detection_map_op.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ def set_data(self):
3434
'int32')
3535
self.true_pos = np.array(self.true_pos).astype('float32')
3636
self.false_pos = np.array(self.false_pos).astype('float32')
37+
self.has_state = np.array([1]).astype('int32')
3738

3839
self.inputs = {
3940
'Label': (self.label, self.label_lod),
4041
'DetectRes': (self.detect, self.detect_lod),
42+
'HasState': self.has_state,
4143
'PosCount': self.class_pos_count,
4244
'TruePos': (self.true_pos, self.true_pos_lod),
4345
'FalsePos': (self.false_pos, self.false_pos_lod)

0 commit comments

Comments
 (0)