Skip to content

Commit ab068c9

Browse files
author
me
committed
assigner: cast to float32 as well in case you're training with bf16. Added focal loss
1 parent 64d9566 commit ab068c9

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

src/assigner.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,9 +266,9 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> tal (
266266

267267
// Store device for later then put everything on CPU
268268
auto device = pred_boxes.device();
269-
pred_boxes = pred_boxes.to(torch::TensorOptions(torch::Device("cpu")));
270-
pred_scores = pred_scores.to(torch::TensorOptions(torch::Device("cpu")));
271-
targets = targets.to(torch::TensorOptions(torch::Device("cpu")));
269+
pred_boxes = pred_boxes.to(torch::kCPU, torch::kFloat);
270+
pred_scores = pred_scores.to(torch::kCPU, torch::kFloat);
271+
targets = targets.to(torch::kCPU, torch::kFloat);
272272

273273
// Outputs
274274
auto boxes = torch::zeros({B, N, 4});

src/models.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ def exists(val):
107107
def default(val, d):
    """Return `val` when it exists; otherwise fall back to the default `d`."""
    if exists(val):
        return val
    return d
109109

110+
def default_lazy(x, fn):
    """Lazy counterpart of `default`: return `x` if it exists, otherwise
    call the zero-arg factory `fn` to produce the fallback value. The
    fallback is only computed when actually needed."""
    if exists(x):
        return x
    return fn()
112+
110113
def Repeat(module, N):
111114
return nn.Sequential(*[deepcopy(module) for _ in range(N)])
112115

@@ -886,6 +889,15 @@ def box2dist(box, sxy, strides):
886889
dist = torch.cat([lt,rb], -1)
887890
return dist
888891

892+
def focal_loss(logits, scores, soft_targets, gamma: float = 2.0, reduction: str = 'sum'):
    """Focal-style classification loss: elementwise BCE-with-logits scaled
    by the modulating factor |p - target|**gamma.

    logits:       raw (pre-sigmoid) class predictions.
    scores:       optional precomputed probabilities; when None they are
                  derived lazily as logits.sigmoid().
    soft_targets: per-class soft target values in [0, 1].
    gamma:        focusing exponent of the modulating factor.
    reduction:    'sum' or 'mean'; any other value returns the
                  unreduced elementwise loss.
    """
    probs = default_lazy(scores, lambda: logits.sigmoid())
    per_elem = F.binary_cross_entropy_with_logits(logits, soft_targets, reduction="none")
    # Down-weight well-predicted elements, emphasize hard / misclassified ones.
    per_elem = per_elem * (probs - soft_targets).abs().pow(gamma)
    if reduction == 'sum':
        return per_elem.sum()
    if reduction == 'mean':
        return per_elem.mean()
    return per_elem
900+
889901
@torch.no_grad()
890902
def make_anchors(feats, strides): # anchor-free
891903
xys, strides2 = [], []
@@ -1028,7 +1040,8 @@ def forward_private(self, xs, cv2, cv3, targets=None):
10281040
else: loss_dfl = (F.l1_loss(ltrb[mask], box2dist(tboxes, sxy, strides)[mask], reduction='none') * weight.unsqueeze(-1)).sum() / tgt_scores_sum
10291041

10301042
# Class loss (positive samples + negative)
1031-
loss_cls = F.binary_cross_entropy_with_logits(logits, tcls*tscores.unsqueeze(-1), reduction='sum') / tgt_scores_sum
1043+
# loss_cls = F.binary_cross_entropy_with_logits(logits, tcls*tscores.unsqueeze(-1), reduction='sum') / tgt_scores_sum
1044+
loss_cls = focal_loss(logits, probs, tcls*tscores.unsqueeze(-1), gamma=2.0, reduction='sum') / tgt_scores_sum
10321045

10331046
return pred if not exists(targets) else (pred, {'iou': loss_iou, 'dfl': loss_dfl, 'cls': loss_cls})
10341047

0 commit comments

Comments (0)