Skip to content

Commit d3856ae

Browse files
author
pfeatherstone
committed
- don't compute cls.sigmoid() twice. rename cls to logits and compute scores/probs once.
- tgt_scores_sum potential bug fix - optionally compute DFL loss if DFL is enabled. -
1 parent e42d8bc commit d3856ae

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

src/models.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,7 @@ def make_anchors(feats, strides): # anchor-free
890890
xy = rearrange([sx,sy], 'c h w -> (h w) c')
891891
xys.append(xy)
892892
strides2.append(torch.full((h*w,1), fill_value=stride, device=x.device))
893-
return *pack(xys, '* c'), torch.cat(strides2,0)
893+
return torch.cat(xys,0), torch.cat(strides2,0)
894894

895895
@torch.no_grad()
896896
def make_anchors_ab(feats, strides, scales, anchors): # anchor-based
@@ -970,7 +970,7 @@ def forward(self, xs, targets=None):
970970
mask = tscores > 0
971971

972972
# CIOU loss (positive samples)
973-
tgt_scores_sum = max(tscores.sum(), 1)
973+
tgt_scores_sum = tscores.sum().clamp(min=1.0)
974974
weight = tscores[mask]
975975
loss_iou = (torchvision.ops.complete_box_iou_loss(box[mask], tboxes[mask], reduction='none') * weight).sum() / tgt_scores_sum
976976

@@ -1001,32 +1001,29 @@ def spconv(c1, c2, k): return nn.Sequential(Conv(c1,c1,k,g=c1),Conv(c1,c2,1))
10011001
if end2end: self.one2one_cv3 = deepcopy(self.cv3)
10021002

10031003
def forward_private(self, xs, cv2, cv3, targets=None):
1004-
sxy, ps, strides = make_anchors(xs, self.strides)
1005-
feats = [rearrange(torch.cat((c1(x), c2(x)), 1), 'b f h w -> b (h w) f') for x,c1,c2 in zip(xs, cv2, cv3)]
1006-
dist, cls = torch.cat(feats, 1).split((4 * self.reg_max, self.nc), -1)
1007-
dist = rearrange(dist, 'b n (k r) -> b n k r', k=4)
1008-
ltrb = torch.einsum('bnkr, r -> bnk', dist.softmax(-1), self.r) if self.dfl else dist.squeeze(-1)
1009-
box = dist2box(ltrb, sxy, strides)
1010-
pred = torch.cat((box, cls.sigmoid()), -1)
1004+
sxy, strides = make_anchors(xs, self.strides)
1005+
feats = [rearrange(torch.cat((c1(x), c2(x)), 1), 'b f h w -> b (h w) f') for x,c1,c2 in zip(xs, cv2, cv3)]
1006+
dist, logits = torch.cat(feats, 1).split((4 * self.reg_max, self.nc), -1)
1007+
dist = rearrange(dist, 'b n (k r) -> b n k r', k=4)
1008+
ltrb = torch.einsum('bnkr, r -> bnk', dist.softmax(-1), self.r) if self.dfl else dist.squeeze(-1)
1009+
box = dist2box(ltrb, sxy, strides)
1010+
probs = logits.sigmoid()
1011+
pred = torch.cat((box, probs), -1)
10111012

10121013
if exists(targets):
1013-
# awh = torch.full_like(sxy, fill_value=5.0) * strides # Fake height and width for the sake of ATSS
1014-
# anchors = torch.cat([sxy-awh/2, sxy+awh/2],-1)
1015-
# tboxes, tscores, tcls = assigner.atss(anchors, targets, [p[0] for p in ps], self.nc, 9)
1016-
tboxes, tscores, tcls = assigner.tal(box, cls.sigmoid(), sxy, targets, 9, 0.5, 6.0)
1017-
# tboxes, tscores, tcls = assigner.fcos(sxy, targets, self.nc)
1018-
mask = tscores > 0
1014+
tboxes, tscores, tcls = assigner.tal(box, probs, sxy, targets, 9, 0.5, 6.0)
1015+
mask = tscores > 0
10191016

10201017
# CIOU loss (positive samples)
1021-
tgt_scores_sum = max(tscores.sum(), 1)
1018+
tgt_scores_sum = tscores.sum().clamp(min=1.0)
10221019
weight = tscores[mask]
10231020
loss_iou = (torchvision.ops.complete_box_iou_loss(box[mask], tboxes[mask], reduction='none') * weight).sum() / tgt_scores_sum
10241021

10251022
# DFL loss (positive samples)
1026-
loss_dfl = dfl_loss(tboxes, mask, tgt_scores_sum, sxy, strides, dist)
1023+
loss_dfl = dfl_loss(tboxes, mask, tgt_scores_sum, sxy, strides, dist) if self.dfl else torch.zeros((), device=box.device)
10271024

10281025
# Class loss (positive samples + negative)
1029-
loss_cls = F.binary_cross_entropy_with_logits(cls, tcls*tscores.unsqueeze(-1), reduction='sum') / tgt_scores_sum
1026+
loss_cls = F.binary_cross_entropy_with_logits(logits, tcls*tscores.unsqueeze(-1), reduction='sum') / tgt_scores_sum
10301027

10311028
return pred if not exists(targets) else (pred, {'iou': loss_iou, 'dfl': loss_dfl, 'cls': loss_cls})
10321029

0 commit comments

Comments
 (0)