@@ -71,6 +71,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7171 y_pred = torch .sigmoid (y_pred )
7272 y_pred = torch .cat ([1 - y_pred , y_pred ], dim = 1 )
7373 is_already_prob = True
74+ # Expand y_true to match if it's single channel
7475 if y_true .shape [1 ] == 1 :
7576 y_true = one_hot (y_true , num_classes = 2 )
7677 else :
@@ -213,17 +214,15 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
213214 # Concatenate losses
214215 all_ce = torch .cat ([back_ce .unsqueeze (1 ), fore_ce ], dim = 1 )
215216
216- # Sum over classes (dim=1) to get total loss per pixel
217- total_loss = torch .sum (all_ce , dim = 1 )
218-
219217 # Apply reduction
220218 if self .reduction == LossReduction .MEAN .value :
221- return torch .mean (total_loss )
219+ return torch .mean (torch . sum ( all_ce , dim = 1 ) )
222220 if self .reduction == LossReduction .SUM .value :
223- return torch .sum (total_loss )
221+ return torch .sum (all_ce )
224222 if self .reduction == LossReduction .NONE .value :
225- return total_loss
226- return torch .mean (total_loss )
223+ return all_ce
224+
225+ return torch .mean (torch .sum (all_ce , dim = 1 ))
227226
228227
229228class AsymmetricUnifiedFocalLoss (_Loss ):
@@ -268,14 +267,14 @@ def __init__(
268267 delta = self .delta ,
269268 use_softmax = self .use_softmax ,
270269 to_onehot_y = to_onehot_y ,
271- reduction = reduction ,
270+ reduction = LossReduction . NONE ,
272271 )
273272 self .asy_focal_tversky_loss = AsymmetricFocalTverskyLoss (
274273 gamma = self .gamma ,
275274 delta = self .delta ,
276275 use_softmax = self .use_softmax ,
277276 to_onehot_y = to_onehot_y ,
278- reduction = reduction ,
277+ reduction = LossReduction . NONE ,
279278 )
280279
281280 def forward (self , y_pred : torch .Tensor , y_true : torch .Tensor ) -> torch .Tensor :
@@ -293,6 +292,21 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
293292 asy_focal_loss = self .asy_focal_loss (y_pred , y_true )
294293 asy_focal_tversky_loss = self .asy_focal_tversky_loss (y_pred , y_true )
295294
296- loss : torch .Tensor = self .weight * asy_focal_loss + (1 - self .weight ) * asy_focal_tversky_loss
295+ # Align Focal Loss to (B, C) by averaging over spatial dimensions
296+ spatial_dims = list (range (2 , len (asy_focal_loss .shape )))
297+ focal_aligned = torch .mean (asy_focal_loss , dim = spatial_dims )
298+
299+ # Calculate weighted sum. Result shape: (B, C)
300+ combined_loss = self .weight * focal_aligned + (1 - self .weight ) * asy_focal_tversky_loss
301+
302+ loss : torch .Tensor
303+ if self .reduction == LossReduction .MEAN .value :
304+ loss = torch .mean (combined_loss )
305+ elif self .reduction == LossReduction .SUM .value :
306+ loss = torch .sum (combined_loss )
307+ elif self .reduction == LossReduction .NONE .value :
308+ loss = combined_loss
309+ else :
310+ loss = torch .mean (combined_loss )
297311
298312 return loss
0 commit comments