@@ -257,7 +257,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> tal (
257257 torch::Tensor pred_scores, // [B, N, nc]
258258 torch::Tensor sxy, // [N, 2]
259259 torch::Tensor targets, // [B, D, 5]
260- const long topk, // Number of candidates per pyramic level
260+ const long topk,
261261 const float alpha,
262262 const float beta
263263)
@@ -290,14 +290,14 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> tal (
290290 std::vector<float > metric (N);
291291 std::vector<float > ious (N);
292292 std::vector<long > indices (N);
293- std::vector<float > best_ious (N);
293+ std::vector<float > best_metric (N);
294294 std::vector<long > best_cls (N);
295295
296296 for (long b = 0 ; b < B ; ++b)
297297 {
298298 // Running track of best IOU and best class for an anchor point
299- std::fill (begin (best_ious ), end (best_ious ), 0 .0f );
300- std::fill (begin (best_cls), end (best_cls), 0 );
299+ std::fill (begin (best_metric ), end (best_metric ), 0 .0f );
300+ std::fill (begin (best_cls), end (best_cls), 0 );
301301
302302 for (long d = 0 ; d < D ; ++d)
303303 {
@@ -311,7 +311,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> tal (
311311 {
312312 const box pbox = {pred_boxes_a[b][n][0 ], pred_boxes_a[b][n][1 ], pred_boxes_a[b][n][2 ], pred_boxes_a[b][n][3 ]};
313313 const float pscore = pred_scores_a[b][n][tcls];
314- ious[n] = iou (tbox, pbox, CIOU );
314+ ious[n] = iou (tbox, pbox, IOU );
315315 metric[n] = pow (pscore, alpha) * pow (ious[n], beta);
316316 }
317317
@@ -333,18 +333,20 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> tal (
333333 // 4. Add to targets2
334334 for (size_t i = 0 ; i < topk ; ++i)
335335 {
336- const long n = indices[i];
336+ const long n = indices[i];
337+ const float scale = max_metric > 0 .f ? (max_iou / max_metric) : 0 .f ;
338+ const float t_hat = metric[n] * scale;
337339
338- if (contains (tbox, sxy_a[n][0 ], sxy_a[n][1 ]) && ious [n] > best_ious [n])
340+ if (contains (tbox, sxy_a[n][0 ], sxy_a[n][1 ]) && metric [n] > best_metric [n])
339341 {
340342 boxes_a[b][n][0 ] = tbox[0 ];
341343 boxes_a[b][n][1 ] = tbox[1 ];
342344 boxes_a[b][n][2 ] = tbox[2 ];
343345 boxes_a[b][n][3 ] = tbox[3 ];
344- scores_a[b][n] = metric[n] * max_iou / max_metric ;
346+ scores_a[b][n] = t_hat ;
345347 classes_a[b][n][best_cls[n]] = 0 ; // Reset last class
346348 classes_a[b][n][tcls] = 1 ;
347- best_ious [n] = ious [n];
349+ best_metric [n] = metric [n];
348350 best_cls[n] = tcls;
349351 }
350352 }
0 commit comments