@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
15
#pragma once
16
+ #include < algorithm>
17
+ #include < map>
18
+ #include < string>
19
+ #include < utility>
20
+ #include < vector>
16
21
#include " paddle/fluid/framework/eigen.h"
17
22
#include " paddle/fluid/framework/op_registry.h"
18
23
@@ -82,7 +87,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
82
87
std::vector<std::map<int , std::vector<Box>>> gt_boxes;
83
88
std::vector<std::map<int , std::vector<std::pair<T, Box>>>> detect_boxes;
84
89
85
- GetBoxes (*in_label, *in_detect, gt_boxes, detect_boxes);
90
+ GetBoxes (*in_label, *in_detect, & gt_boxes, detect_boxes);
86
91
87
92
std::map<int , int > label_pos_count;
88
93
std::map<int , std::vector<std::pair<T, int >>> true_pos;
@@ -95,20 +100,20 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
95
100
}
96
101
97
102
if (in_pos_count != nullptr && state) {
98
- GetInputPos (*in_pos_count, *in_true_pos, *in_false_pos, label_pos_count,
99
- true_pos, false_pos, class_num);
103
+ GetInputPos (*in_pos_count, *in_true_pos, *in_false_pos, & label_pos_count,
104
+ & true_pos, & false_pos, class_num);
100
105
}
101
106
102
107
CalcTrueAndFalsePositive (gt_boxes, detect_boxes, evaluate_difficult,
103
- overlap_threshold, label_pos_count, true_pos,
104
- false_pos);
108
+ overlap_threshold, & label_pos_count, & true_pos,
109
+ & false_pos);
105
110
106
111
int background_label = ctx.Attr <int >(" background_label" );
107
112
T map = CalcMAP (ap_type, label_pos_count, true_pos, false_pos,
108
113
background_label);
109
114
110
- GetOutputPos (ctx, label_pos_count, true_pos, false_pos, * out_pos_count,
111
- * out_true_pos, * out_false_pos, class_num);
115
+ GetOutputPos (ctx, label_pos_count, true_pos, false_pos, out_pos_count,
116
+ out_true_pos, out_false_pos, class_num);
112
117
113
118
T* map_data = out_map->mutable_data <T>(ctx.GetPlace ());
114
119
map_data[0 ] = map;
@@ -155,7 +160,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
155
160
156
161
void GetBoxes (const framework::LoDTensor& input_label,
157
162
const framework::LoDTensor& input_detect,
158
- std::vector<std::map<int , std::vector<Box>>>& gt_boxes,
163
+ std::vector<std::map<int , std::vector<Box>>>* gt_boxes,
159
164
std::vector<std::map<int , std::vector<std::pair<T, Box>>>>&
160
165
detect_boxes) const {
161
166
auto labels = framework::EigenTensor<T, 2 >::From (input_label);
@@ -179,7 +184,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
179
184
box.is_difficult = true ;
180
185
boxes[label].push_back (box);
181
186
}
182
- gt_boxes. push_back (boxes);
187
+ gt_boxes-> push_back (boxes);
183
188
}
184
189
185
190
auto detect_index = detect_lod[0 ];
@@ -200,9 +205,9 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
200
205
const std::map<int , int >& label_pos_count,
201
206
const std::map<int , std::vector<std::pair<T, int >>>& true_pos,
202
207
const std::map<int , std::vector<std::pair<T, int >>>& false_pos,
203
- framework::Tensor& output_pos_count,
204
- framework::LoDTensor& output_true_pos,
205
- framework::LoDTensor& output_false_pos, const int class_num) const {
208
+ framework::Tensor* output_pos_count,
209
+ framework::LoDTensor* output_true_pos,
210
+ framework::LoDTensor* output_false_pos, const int class_num) const {
206
211
int true_pos_count = 0 ;
207
212
int false_pos_count = 0 ;
208
213
for (auto it = true_pos.begin (); it != true_pos.end (); ++it) {
@@ -214,12 +219,12 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
214
219
false_pos_count += fp.size ();
215
220
}
216
221
217
- int * pos_count_data = output_pos_count. mutable_data <int >(
222
+ int * pos_count_data = output_pos_count-> mutable_data <int >(
218
223
framework::make_ddim ({class_num, 1 }), ctx.GetPlace ());
219
224
220
- T* true_pos_data = output_true_pos. mutable_data <T>(
225
+ T* true_pos_data = output_true_pos-> mutable_data <T>(
221
226
framework::make_ddim ({true_pos_count, 2 }), ctx.GetPlace ());
222
- T* false_pos_data = output_false_pos. mutable_data <T>(
227
+ T* false_pos_data = output_false_pos-> mutable_data <T>(
223
228
framework::make_ddim ({false_pos_count, 2 }), ctx.GetPlace ());
224
229
true_pos_count = 0 ;
225
230
false_pos_count = 0 ;
@@ -261,21 +266,21 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
261
266
framework::LoD false_pos_lod;
262
267
false_pos_lod.emplace_back (false_pos_starts);
263
268
264
- output_true_pos. set_lod (true_pos_lod);
265
- output_false_pos. set_lod (false_pos_lod);
269
+ output_true_pos-> set_lod (true_pos_lod);
270
+ output_false_pos-> set_lod (false_pos_lod);
266
271
return ;
267
272
}
268
273
269
274
void GetInputPos (const framework::Tensor& input_pos_count,
270
275
const framework::LoDTensor& input_true_pos,
271
276
const framework::LoDTensor& input_false_pos,
272
- std::map<int , int >& label_pos_count,
273
- std::map<int , std::vector<std::pair<T, int >>>& true_pos,
274
- std::map<int , std::vector<std::pair<T, int >>>& false_pos,
277
+ std::map<int , int >* label_pos_count,
278
+ std::map<int , std::vector<std::pair<T, int >>>* true_pos,
279
+ std::map<int , std::vector<std::pair<T, int >>>* false_pos,
275
280
const int class_num) const {
276
281
const int * pos_count_data = input_pos_count.data <int >();
277
282
for (int i = 0 ; i < class_num; ++i) {
278
- label_pos_count[i] = pos_count_data[i];
283
+ (* label_pos_count) [i] = pos_count_data[i];
279
284
}
280
285
281
286
auto SetData = [](const framework::LoDTensor& pos_tensor,
@@ -291,8 +296,8 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
291
296
}
292
297
};
293
298
294
- SetData (input_true_pos, true_pos);
295
- SetData (input_false_pos, false_pos);
299
+ SetData (input_true_pos, * true_pos);
300
+ SetData (input_false_pos, * false_pos);
296
301
return ;
297
302
}
298
303
@@ -301,9 +306,9 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
301
306
const std::vector<std::map<int , std::vector<std::pair<T, Box>>>>&
302
307
detect_boxes,
303
308
bool evaluate_difficult, float overlap_threshold,
304
- std::map<int , int >& label_pos_count,
305
- std::map<int , std::vector<std::pair<T, int >>>& true_pos,
306
- std::map<int , std::vector<std::pair<T, int >>>& false_pos) const {
309
+ std::map<int , int >* label_pos_count,
310
+ std::map<int , std::vector<std::pair<T, int >>>* true_pos,
311
+ std::map<int , std::vector<std::pair<T, int >>>* false_pos) const {
307
312
int batch_size = gt_boxes.size ();
308
313
for (int n = 0 ; n < batch_size; ++n) {
309
314
auto image_gt_boxes = gt_boxes[n];
@@ -320,10 +325,10 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
320
325
continue ;
321
326
}
322
327
int label = it->first ;
323
- if (label_pos_count. find (label) == label_pos_count. end ()) {
324
- label_pos_count[label] = count;
328
+ if (label_pos_count-> find (label) == label_pos_count-> end ()) {
329
+ (* label_pos_count) [label] = count;
325
330
} else {
326
- label_pos_count[label] += count;
331
+ (* label_pos_count) [label] += count;
327
332
}
328
333
}
329
334
}
@@ -338,8 +343,8 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
338
343
int label = it->first ;
339
344
for (size_t i = 0 ; i < pred_boxes.size (); ++i) {
340
345
auto score = pred_boxes[i].first ;
341
- true_pos[label].push_back (std::make_pair (score, 0 ));
342
- false_pos[label].push_back (std::make_pair (score, 1 ));
346
+ (* true_pos) [label].push_back (std::make_pair (score, 0 ));
347
+ (* false_pos) [label].push_back (std::make_pair (score, 1 ));
343
348
}
344
349
}
345
350
continue ;
@@ -351,8 +356,8 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
351
356
if (image_gt_boxes.find (label) == image_gt_boxes.end ()) {
352
357
for (size_t i = 0 ; i < pred_boxes.size (); ++i) {
353
358
auto score = pred_boxes[i].first ;
354
- true_pos[label].push_back (std::make_pair (score, 0 ));
355
- false_pos[label].push_back (std::make_pair (score, 1 ));
359
+ (* true_pos) [label].push_back (std::make_pair (score, 0 ));
360
+ (* false_pos) [label].push_back (std::make_pair (score, 1 ));
356
361
}
357
362
continue ;
358
363
}
@@ -381,17 +386,17 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
381
386
(!evaluate_difficult && !matched_bboxes[max_idx].is_difficult );
382
387
if (match_evaluate_difficult) {
383
388
if (!visited[max_idx]) {
384
- true_pos[label].push_back (std::make_pair (score, 1 ));
385
- false_pos[label].push_back (std::make_pair (score, 0 ));
389
+ (* true_pos) [label].push_back (std::make_pair (score, 1 ));
390
+ (* false_pos) [label].push_back (std::make_pair (score, 0 ));
386
391
visited[max_idx] = true ;
387
392
} else {
388
- true_pos[label].push_back (std::make_pair (score, 0 ));
389
- false_pos[label].push_back (std::make_pair (score, 1 ));
393
+ (* true_pos) [label].push_back (std::make_pair (score, 0 ));
394
+ (* false_pos) [label].push_back (std::make_pair (score, 1 ));
390
395
}
391
396
}
392
397
} else {
393
- true_pos[label].push_back (std::make_pair (score, 0 ));
394
- false_pos[label].push_back (std::make_pair (score, 1 ));
398
+ (* true_pos) [label].push_back (std::make_pair (score, 0 ));
399
+ (* false_pos) [label].push_back (std::make_pair (score, 1 ));
395
400
}
396
401
}
397
402
}
0 commit comments