@@ -890,7 +890,7 @@ def make_anchors(feats, strides): # anchor-free
890890 xy = rearrange ([sx ,sy ], 'c h w -> (h w) c' )
891891 xys .append (xy )
892892 strides2 .append (torch .full ((h * w ,1 ), fill_value = stride , device = x .device ))
893- return * pack (xys , '* c' ), torch .cat (strides2 ,0 )
893+ return torch . cat (xys ,0 ), torch .cat (strides2 ,0 )
894894
895895@torch .no_grad ()
896896def make_anchors_ab (feats , strides , scales , anchors ): # anchor-based
@@ -970,7 +970,7 @@ def forward(self, xs, targets=None):
970970 mask = tscores > 0
971971
972972 # CIOU loss (positive samples)
973- tgt_scores_sum = max ( tscores .sum (), 1 )
973+ tgt_scores_sum = tscores .sum (). clamp ( min = 1.0 )
974974 weight = tscores [mask ]
975975 loss_iou = (torchvision .ops .complete_box_iou_loss (box [mask ], tboxes [mask ], reduction = 'none' ) * weight ).sum () / tgt_scores_sum
976976
@@ -1001,32 +1001,29 @@ def spconv(c1, c2, k): return nn.Sequential(Conv(c1,c1,k,g=c1),Conv(c1,c2,1))
10011001 if end2end : self .one2one_cv3 = deepcopy (self .cv3 )
10021002
10031003 def forward_private (self , xs , cv2 , cv3 , targets = None ):
1004- sxy , ps , strides = make_anchors (xs , self .strides )
1005- feats = [rearrange (torch .cat ((c1 (x ), c2 (x )), 1 ), 'b f h w -> b (h w) f' ) for x ,c1 ,c2 in zip (xs , cv2 , cv3 )]
1006- dist , cls = torch .cat (feats , 1 ).split ((4 * self .reg_max , self .nc ), - 1 )
1007- dist = rearrange (dist , 'b n (k r) -> b n k r' , k = 4 )
1008- ltrb = torch .einsum ('bnkr, r -> bnk' , dist .softmax (- 1 ), self .r ) if self .dfl else dist .squeeze (- 1 )
1009- box = dist2box (ltrb , sxy , strides )
1010- pred = torch .cat ((box , cls .sigmoid ()), - 1 )
1004+ sxy , strides = make_anchors (xs , self .strides )
1005+ feats = [rearrange (torch .cat ((c1 (x ), c2 (x )), 1 ), 'b f h w -> b (h w) f' ) for x ,c1 ,c2 in zip (xs , cv2 , cv3 )]
1006+ dist , logits = torch .cat (feats , 1 ).split ((4 * self .reg_max , self .nc ), - 1 )
1007+ dist = rearrange (dist , 'b n (k r) -> b n k r' , k = 4 )
1008+ ltrb = torch .einsum ('bnkr, r -> bnk' , dist .softmax (- 1 ), self .r ) if self .dfl else dist .squeeze (- 1 )
1009+ box = dist2box (ltrb , sxy , strides )
1010+ probs = logits .sigmoid ()
1011+ pred = torch .cat ((box , probs ), - 1 )
10111012
10121013 if exists (targets ):
1013- # awh = torch.full_like(sxy, fill_value=5.0) * strides # Fake height and width for the sake of ATSS
1014- # anchors = torch.cat([sxy-awh/2, sxy+awh/2],-1)
1015- # tboxes, tscores, tcls = assigner.atss(anchors, targets, [p[0] for p in ps], self.nc, 9)
1016- tboxes , tscores , tcls = assigner .tal (box , cls .sigmoid (), sxy , targets , 9 , 0.5 , 6.0 )
1017- # tboxes, tscores, tcls = assigner.fcos(sxy, targets, self.nc)
1018- mask = tscores > 0
1014+ tboxes , tscores , tcls = assigner .tal (box , probs , sxy , targets , 9 , 0.5 , 6.0 )
1015+ mask = tscores > 0
10191016
10201017 # CIOU loss (positive samples)
1021- tgt_scores_sum = max ( tscores .sum (), 1 )
1018+ tgt_scores_sum = tscores .sum (). clamp ( min = 1.0 )
10221019 weight = tscores [mask ]
10231020 loss_iou = (torchvision .ops .complete_box_iou_loss (box [mask ], tboxes [mask ], reduction = 'none' ) * weight ).sum () / tgt_scores_sum
10241021
10251022 # DFL loss (positive samples)
1026- loss_dfl = dfl_loss (tboxes , mask , tgt_scores_sum , sxy , strides , dist )
1023+ loss_dfl = dfl_loss (tboxes , mask , tgt_scores_sum , sxy , strides , dist ) if self . dfl else torch . zeros ((), device = box . device )
10271024
10281025 # Class loss (positive samples + negative)
1029- loss_cls = F .binary_cross_entropy_with_logits (cls , tcls * tscores .unsqueeze (- 1 ), reduction = 'sum' ) / tgt_scores_sum
1026+ loss_cls = F .binary_cross_entropy_with_logits (logits , tcls * tscores .unsqueeze (- 1 ), reduction = 'sum' ) / tgt_scores_sum
10301027
10311028 return pred if not exists (targets ) else (pred , {'iou' : loss_iou , 'dfl' : loss_dfl , 'cls' : loss_cls })
10321029
0 commit comments