@@ -38,22 +38,22 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
38
38
auto box_dims = ctx->GetInputDim (" BBoxes" );
39
39
auto score_dims = ctx->GetInputDim (" Scores" );
40
40
41
- PADDLE_ENFORCE_EQ (box_dims.size (), 2 ,
42
- " The rank of Input(BBoxes) must be 2 ." );
41
+ PADDLE_ENFORCE_EQ (box_dims.size (), 3 ,
42
+ " The rank of Input(BBoxes) must be 3 ." );
43
43
PADDLE_ENFORCE_EQ (score_dims.size (), 3 ,
44
44
" The rank of Input(Scores) must be 3." );
45
- PADDLE_ENFORCE_EQ (box_dims[1 ], 4 ,
45
+ PADDLE_ENFORCE_EQ (box_dims[2 ], 4 ,
46
46
" The 2nd dimension of Input(BBoxes) must be 4, "
47
47
" represents the layout of coordinate "
48
48
" [xmin, ymin, xmax, ymax]" );
49
- PADDLE_ENFORCE_EQ (box_dims[0 ], score_dims[2 ],
49
+ PADDLE_ENFORCE_EQ (box_dims[1 ], score_dims[2 ],
50
50
" The 1st dimensiong of Input(BBoxes) must be equal to "
51
51
" 3rd dimension of Input(Scores), which represents the "
52
52
" predicted bboxes." );
53
53
54
54
// Here the box_dims[0] is not the real dimension of output.
55
55
// It will be rewritten in the computing kernel.
56
- ctx->SetOutputDim (" Out" , {box_dims[0 ], 6 });
56
+ ctx->SetOutputDim (" Out" , {box_dims[1 ], 6 });
57
57
}
58
58
59
59
protected:
@@ -260,15 +260,20 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
260
260
int64_t batch_size = score_dims[0 ];
261
261
int64_t class_num = score_dims[1 ];
262
262
int64_t predict_dim = score_dims[2 ];
263
+ int64_t box_dim = boxes->dims ()[2 ];
263
264
264
265
std::vector<std::map<int , std::vector<int >>> all_indices;
265
266
std::vector<size_t > batch_starts = {0 };
266
267
for (int64_t i = 0 ; i < batch_size; ++i) {
267
268
Tensor ins_score = scores->Slice (i, i + 1 );
268
269
ins_score.Resize ({class_num, predict_dim});
270
+
271
+ Tensor ins_boxes = boxes->Slice (i, i + 1 );
272
+ ins_boxes.Resize ({predict_dim, box_dim});
273
+
269
274
std::map<int , std::vector<int >> indices;
270
275
int num_nmsed_out = 0 ;
271
- MultiClassNMS (ctx, ins_score, *boxes , indices, num_nmsed_out);
276
+ MultiClassNMS (ctx, ins_score, ins_boxes , indices, num_nmsed_out);
272
277
all_indices.push_back (indices);
273
278
batch_starts.push_back (batch_starts.back () + num_nmsed_out);
274
279
}
@@ -282,11 +287,15 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
282
287
for (int64_t i = 0 ; i < batch_size; ++i) {
283
288
Tensor ins_score = scores->Slice (i, i + 1 );
284
289
ins_score.Resize ({class_num, predict_dim});
290
+
291
+ Tensor ins_boxes = boxes->Slice (i, i + 1 );
292
+ ins_boxes.Resize ({predict_dim, box_dim});
293
+
285
294
int64_t s = batch_starts[i];
286
295
int64_t e = batch_starts[i + 1 ];
287
296
if (e > s) {
288
297
Tensor out = outs->Slice (s, e);
289
- MultiClassOutput (ins_score, *boxes , all_indices[i], &out);
298
+ MultiClassOutput (ins_score, ins_boxes , all_indices[i], &out);
290
299
}
291
300
}
292
301
}
@@ -303,9 +312,9 @@ class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker {
303
312
MultiClassNMSOpMaker (OpProto* proto, OpAttrChecker* op_checker)
304
313
: OpProtoAndCheckerMaker(proto, op_checker) {
305
314
AddInput (" BBoxes" ,
306
- " (Tensor) A 2 -D Tensor with shape [M, 4] represents the "
307
- " predicted locations of M bounding bboxes. Each bounding box "
308
- " has four coordinate values and the layout is "
315
+ " (Tensor) A 3 -D Tensor with shape [N, M, 4] represents the "
316
+ " predicted locations of M bounding bboxes, N is the batch size. "
317
+ " Each bounding box has four coordinate values and the layout is "
309
318
" [xmin, ymin, xmax, ymax]." );
310
319
AddInput (" Scores" ,
311
320
" (Tensor) A 3-D Tensor with shape [N, C, M] represents the "
0 commit comments