Skip to content

Commit e42d8bc

Browse files
author
pfeatherstone
committed
TAL: uses normal IOU in metric calculation instead of CIOU. Avoid divide by 0 when computing t_hat, i.e. normalized metric. Rank anchors by best metric, not normalized anchors.
1 parent ab179dd commit e42d8bc

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

src/assigner.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> tal (
257257
torch::Tensor pred_scores, // [B, N, nc]
258258
torch::Tensor sxy, // [N, 2]
259259
torch::Tensor targets, // [B, D, 5]
260-
const long topk, // Number of candidates per pyramic level
260+
const long topk,
261261
const float alpha,
262262
const float beta
263263
)
@@ -290,14 +290,14 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> tal (
290290
std::vector<float> metric(N);
291291
std::vector<float> ious(N);
292292
std::vector<long> indices(N);
293-
std::vector<float> best_ious(N);
293+
std::vector<float> best_metric(N);
294294
std::vector<long> best_cls(N);
295295

296296
for (long b = 0 ; b < B ; ++b)
297297
{
298298
// Running track of best IOU and best class for an anchor point
299-
std::fill(begin(best_ious), end(best_ious), 0.0f);
300-
std::fill(begin(best_cls), end(best_cls), 0);
299+
std::fill(begin(best_metric), end(best_metric), 0.0f);
300+
std::fill(begin(best_cls), end(best_cls), 0);
301301

302302
for (long d = 0 ; d < D ; ++d)
303303
{
@@ -311,7 +311,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> tal (
311311
{
312312
const box pbox = {pred_boxes_a[b][n][0], pred_boxes_a[b][n][1], pred_boxes_a[b][n][2], pred_boxes_a[b][n][3]};
313313
const float pscore = pred_scores_a[b][n][tcls];
314-
ious[n] = iou(tbox, pbox, CIOU);
314+
ious[n] = iou(tbox, pbox, IOU);
315315
metric[n] = pow(pscore, alpha) * pow(ious[n], beta);
316316
}
317317

@@ -333,18 +333,20 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> tal (
333333
// 4. Add to targets2
334334
for (size_t i = 0 ; i < topk ; ++i)
335335
{
336-
const long n = indices[i];
336+
const long n = indices[i];
337+
const float scale = max_metric > 0.f ? (max_iou / max_metric) : 0.f;
338+
const float t_hat = metric[n] * scale;
337339

338-
if (contains(tbox, sxy_a[n][0], sxy_a[n][1]) && ious[n] > best_ious[n])
340+
if (contains(tbox, sxy_a[n][0], sxy_a[n][1]) && metric[n] > best_metric[n])
339341
{
340342
boxes_a[b][n][0] = tbox[0];
341343
boxes_a[b][n][1] = tbox[1];
342344
boxes_a[b][n][2] = tbox[2];
343345
boxes_a[b][n][3] = tbox[3];
344-
scores_a[b][n] = metric[n] * max_iou / max_metric;
346+
scores_a[b][n] = t_hat;
345347
classes_a[b][n][best_cls[n]] = 0; // Reset last class
346348
classes_a[b][n][tcls] = 1;
347-
best_ious[n] = ious[n];
349+
best_metric[n] = metric[n];
348350
best_cls[n] = tcls;
349351
}
350352
}

0 commit comments

Comments
 (0)