@@ -69,6 +69,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
69
69
float overlap_threshold = ctx.Attr <float >(" overlap_threshold" );
70
70
float evaluate_difficult = ctx.Attr <bool >(" evaluate_difficult" );
71
71
auto ap_type = GetAPType (ctx.Attr <std::string>(" ap_type" ));
72
+ int class_num = ctx.Attr <int >(" class_num" );
72
73
73
74
auto label_lod = in_label->lod ();
74
75
auto detect_lod = in_detect->lod ();
@@ -95,17 +96,19 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
95
96
96
97
if (in_pos_count != nullptr && state) {
97
98
GetInputPos (*in_pos_count, *in_true_pos, *in_false_pos, label_pos_count,
98
- true_pos, false_pos);
99
+ true_pos, false_pos, class_num );
99
100
}
100
101
101
102
CalcTrueAndFalsePositive (gt_boxes, detect_boxes, evaluate_difficult,
102
103
overlap_threshold, label_pos_count, true_pos,
103
104
false_pos);
104
105
105
- T map = CalcMAP (ap_type, label_pos_count, true_pos, false_pos);
106
+ int background_label = ctx.Attr <int >(" background_label" );
107
+ T map = CalcMAP (ap_type, label_pos_count, true_pos, false_pos,
108
+ background_label);
106
109
107
110
GetOutputPos (ctx, label_pos_count, true_pos, false_pos, *out_pos_count,
108
- *out_true_pos, *out_false_pos);
111
+ *out_true_pos, *out_false_pos, class_num );
109
112
110
113
T* map_data = out_map->mutable_data <T>(ctx.GetPlace ());
111
114
map_data[0 ] = map;
@@ -190,24 +193,20 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
190
193
const std::map<int , std::vector<std::pair<T, int >>>& false_pos,
191
194
framework::Tensor& output_pos_count,
192
195
framework::LoDTensor& output_true_pos,
193
- framework::LoDTensor& output_false_pos) const {
194
- int max_class_id = 0 ;
196
+ framework::LoDTensor& output_false_pos, const int class_num) const {
195
197
int true_pos_count = 0 ;
196
198
int false_pos_count = 0 ;
197
- for (auto it = label_pos_count.begin (); it != label_pos_count.end (); ++it) {
198
- int label = it->first ;
199
- if (label > max_class_id) max_class_id = label;
200
- int label_num_pos = it->second ;
201
- if (label_num_pos == 0 || true_pos.find (label) == true_pos.end ())
202
- continue ;
203
- auto label_true_pos = true_pos.find (label)->second ;
204
- auto label_false_pos = false_pos.find (label)->second ;
205
- true_pos_count += label_true_pos.size ();
206
- false_pos_count += label_false_pos.size ();
199
+ for (auto it = true_pos.begin (); it != true_pos.end (); ++it) {
200
+ auto tp = it->second ;
201
+ true_pos_count += tp.size ();
202
+ }
203
+ for (auto it = false_pos.begin (); it != false_pos.end (); ++it) {
204
+ auto fp = it->second ;
205
+ false_pos_count += fp.size ();
207
206
}
208
207
209
208
int * pos_count_data = output_pos_count.mutable_data <int >(
210
- framework::make_ddim ({max_class_id + 1 , 1 }), ctx.GetPlace ());
209
+ framework::make_ddim ({class_num , 1 }), ctx.GetPlace ());
211
210
212
211
T* true_pos_data = output_true_pos.mutable_data <T>(
213
212
framework::make_ddim ({true_pos_count, 2 }), ctx.GetPlace ());
@@ -217,7 +216,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
217
216
false_pos_count = 0 ;
218
217
std::vector<size_t > true_pos_starts = {0 };
219
218
std::vector<size_t > false_pos_starts = {0 };
220
- for (int i = 0 ; i <= max_class_id ; ++i) {
219
+ for (int i = 0 ; i < class_num ; ++i) {
221
220
auto it_count = label_pos_count.find (i);
222
221
pos_count_data[i] = 0 ;
223
222
if (it_count != label_pos_count.end ()) {
@@ -258,17 +257,16 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
258
257
return ;
259
258
}
260
259
261
- void GetInputPos (
262
- const framework::Tensor& input_pos_count ,
263
- const framework::LoDTensor& input_true_pos ,
264
- const framework::LoDTensor& input_false_pos ,
265
- std::map<int , int >& label_pos_count ,
266
- std::map<int , std::vector<std::pair<T, int >>>& true_pos ,
267
- std::map< int , std::vector<std::pair<T, int >>>& false_pos ) const {
260
+ void GetInputPos (const framework::Tensor& input_pos_count,
261
+ const framework::LoDTensor& input_true_pos ,
262
+ const framework::LoDTensor& input_false_pos ,
263
+ std::map< int , int >& label_pos_count ,
264
+ std::map<int , std::vector<std::pair<T, int >>>& true_pos ,
265
+ std::map<int , std::vector<std::pair<T, int >>>& false_pos ,
266
+ const int class_num ) const {
268
267
constexpr T kEPS = static_cast <T>(1e-6 );
269
- int class_number = input_pos_count.dims ()[0 ];
270
268
const int * pos_count_data = input_pos_count.data <int >();
271
- for (int i = 0 ; i < class_number ; ++i) {
269
+ for (int i = 0 ; i < class_num ; ++i) {
272
270
label_pos_count[i] = pos_count_data[i];
273
271
}
274
272
@@ -391,17 +389,19 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
391
389
}
392
390
}
393
391
394
- T CalcMAP (
395
- APType ap_type, const std::map<int , int >& label_pos_count ,
396
- const std::map<int , std::vector<std::pair<T, int >>>& true_pos ,
397
- const std::map< int , std::vector<std::pair<T, int >>>& false_pos ) const {
392
+ T CalcMAP (APType ap_type, const std::map< int , int >& label_pos_count,
393
+ const std::map<int , std::vector<std::pair<T, int >>>& true_pos ,
394
+ const std::map<int , std::vector<std::pair<T, int >>>& false_pos ,
395
+ const int background_label ) const {
398
396
T mAP = 0.0 ;
399
397
int count = 0 ;
400
398
for (auto it = label_pos_count.begin (); it != label_pos_count.end (); ++it) {
401
399
int label = it->first ;
402
400
int label_num_pos = it->second ;
403
- if (label_num_pos == 0 || true_pos.find (label) == true_pos.end ())
401
+ if (label_num_pos == background_label ||
402
+ true_pos.find (label) == true_pos.end ()) {
404
403
continue ;
404
+ }
405
405
auto label_true_pos = true_pos.find (label)->second ;
406
406
auto label_false_pos = false_pos.find (label)->second ;
407
407
// Compute average precision.
@@ -450,7 +450,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
450
450
}
451
451
}
452
452
if (count != 0 ) mAP /= count;
453
- return mAP * 100 ;
453
+ return mAP ;
454
454
}
455
455
}; // namespace operators
456
456
0 commit comments