@@ -107,6 +107,9 @@ def exists(val):
def default(val, d):
    """Return *val* when it is present (per `exists`), otherwise the fallback *d*."""
    if exists(val):
        return val
    return d
109109
def default_lazy(x, fn):
    """Return *x* when present (per `exists`); otherwise call the zero-arg
    factory *fn* and return its result, so the fallback is computed lazily."""
    return fn() if not exists(x) else x
112+
def Repeat(module, N):
    """Stack N independent deep copies of *module* into an nn.Sequential.

    Each entry is a `deepcopy`, so the repeated modules do not share parameters.
    """
    clones = [deepcopy(module) for _ in range(N)]
    return nn.Sequential(*clones)
112115
@@ -886,6 +889,15 @@ def box2dist(box, sxy, strides):
886889 dist = torch .cat ([lt ,rb ], - 1 )
887890 return dist
888891
def focal_loss(logits, scores, soft_targets, gamma: float = 2.0, reduction: str = 'sum'):
    """Focal loss between raw *logits* and *soft_targets*.

    *scores* — presumably precomputed sigmoid probabilities of the logits;
    when absent (per `exists`, via `default_lazy`) they are recomputed as
    `logits.sigmoid()`. The per-element BCE is scaled by the focal
    modulation |p - t|^gamma. `reduction` of 'sum' or 'mean' collapses the
    result; any other value returns the unreduced elementwise loss.
    """
    probs = default_lazy(scores, lambda: logits.sigmoid())
    per_elem = F.binary_cross_entropy_with_logits(logits, soft_targets, reduction="none")
    per_elem = per_elem * (probs - soft_targets).abs().pow(gamma)
    if reduction == 'sum':
        return per_elem.sum()
    if reduction == 'mean':
        return per_elem.mean()
    return per_elem
900+
889901@torch .no_grad ()
890902def make_anchors (feats , strides ): # anchor-free
891903 xys , strides2 = [], []
@@ -1028,7 +1040,8 @@ def forward_private(self, xs, cv2, cv3, targets=None):
10281040 else : loss_dfl = (F .l1_loss (ltrb [mask ], box2dist (tboxes , sxy , strides )[mask ], reduction = 'none' ) * weight .unsqueeze (- 1 )).sum () / tgt_scores_sum
10291041
10301042 # Class loss (positive samples + negative)
1031- loss_cls = F .binary_cross_entropy_with_logits (logits , tcls * tscores .unsqueeze (- 1 ), reduction = 'sum' ) / tgt_scores_sum
1043+ # loss_cls = F.binary_cross_entropy_with_logits(logits, tcls*tscores.unsqueeze(-1), reduction='sum') / tgt_scores_sum
1044+ loss_cls = focal_loss (logits , probs , tcls * tscores .unsqueeze (- 1 ), gamma = 2.0 , reduction = 'sum' ) / tgt_scores_sum
10321045
10331046 return pred if not exists (targets ) else (pred , {'iou' : loss_iou , 'dfl' : loss_dfl , 'cls' : loss_cls })
10341047
0 commit comments