@@ -71,20 +71,26 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7171 y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims).
7272 y_true: ground truth labels. Shape should match y_pred.
7373 """
74-
74+
75+ # Track whether we've already converted to probabilities
76+ is_converted_to_prob = self .is_prob_input
77+
7578 # Handle single-channel binary case
7679 if y_pred .shape [1 ] == 1 :
7780 # For binary segmentation, enforce num_classes == 2
7881 if self .num_classes != 2 :
79- raise ValueError (f"Single-channel input requires num_classes=2, but got num_classes={ self .num_classes } " )
80-
82+ raise ValueError (
83+ f"Single-channel input requires num_classes=2, but got num_classes={ self .num_classes } "
84+ )
85+
8186 # Convert to probability if needed
82- if not self . is_prob_input :
87+ if not is_converted_to_prob :
8388 y_pred = torch .sigmoid (y_pred )
84-
89+ is_converted_to_prob = True
90+
8591 # Expand to 2 channels
8692 y_pred = torch .cat ([1 - y_pred , y_pred ], dim = 1 )
87-
93+
8894 # Convert y_true to one-hot with 2 classes (matching expanded y_pred)
8995 if y_true .shape [1 ] == 1 :
9096 y_true = one_hot (y_true , num_classes = 2 )
@@ -107,14 +113,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
107113 if y_true .shape [1 ] == 1 :
108114 # Validate class indices
109115 if torch .max (y_true ) >= self .num_classes :
110- raise ValueError (f"Ground truth contains class indices >= { self .num_classes } " )
116+ raise ValueError (
117+ f"Ground truth contains class indices >= { self .num_classes } "
118+ )
111119 y_true = one_hot (y_true , num_classes = self .num_classes )
112120
113121 if y_true .shape != y_pred .shape :
114122 raise ValueError (f"ground truth has different shape ({ y_true .shape } ) from input ({ y_pred .shape } )" )
115123
116124 # Convert logits to probabilities if not already done
117- if not self . is_prob_input and y_pred . shape [ 1 ] != 1 :
125+ if not is_converted_to_prob :
118126 if self .use_softmax :
119127 y_pred = torch .softmax (y_pred , dim = 1 )
120128 else :
@@ -203,19 +211,25 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
203211 y_true: ground truth labels.
204212 """
205213
214+ # Track whether we've already converted to probabilities
215+ is_converted_to_prob = self .is_prob_input
216+
206217 # Handle single-channel binary case
207218 if y_pred .shape [1 ] == 1 :
208219 # For binary segmentation, enforce num_classes == 2
209220 if self .num_classes != 2 :
210- raise ValueError (f"Single-channel input requires num_classes=2, but got num_classes={ self .num_classes } " )
211-
221+ raise ValueError (
222+ f"Single-channel input requires num_classes=2, but got num_classes={ self .num_classes } "
223+ )
224+
212225 # Convert to probability if needed
213- if not self . is_prob_input :
226+ if not is_converted_to_prob :
214227 y_pred = torch .sigmoid (y_pred )
215-
228+ is_converted_to_prob = True
229+
216230 # Expand to 2 channels
217231 y_pred = torch .cat ([1 - y_pred , y_pred ], dim = 1 )
218-
232+
219233 # Convert y_true to one-hot with 2 classes (matching expanded y_pred)
220234 if y_true .shape [1 ] == 1 :
221235 y_true = one_hot (y_true , num_classes = 2 )
@@ -238,14 +252,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
238252 if y_true .shape [1 ] == 1 :
239253 # Validate class indices
240254 if torch .max (y_true ) >= self .num_classes :
241- raise ValueError (f"Ground truth contains class indices >= { self .num_classes } " )
255+ raise ValueError (
256+ f"Ground truth contains class indices >= { self .num_classes } "
257+ )
242258 y_true = one_hot (y_true , num_classes = self .num_classes )
243259
244260 if y_true .shape != y_pred .shape :
245261 raise ValueError (f"ground truth has different shape ({ y_true .shape } ) from input ({ y_pred .shape } )" )
246262
247263 # Convert logits to probabilities if not already done
248- if not self . is_prob_input and y_pred . shape [ 1 ] != 1 :
264+ if not is_converted_to_prob :
249265 if self .use_softmax :
250266 y_pred = torch .softmax (y_pred , dim = 1 )
251267 else :
@@ -267,10 +283,12 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
267283
268284 # Apply reduction
269285 if self .reduction == LossReduction .MEAN .value :
286+ # For MEAN: average over all elements (batch and classes and spatial)
270287 return torch .mean (all_ce )
271288 if self .reduction == LossReduction .SUM .value :
272289 return torch .sum (all_ce )
273290 if self .reduction == LossReduction .NONE .value :
291+ # For NONE: return full tensor with all spatial dimensions
274292 return all_ce
275293
276294 return torch .mean (all_ce )
@@ -316,11 +334,13 @@ def __init__(
316334 >>> fl(pred, grnd)
317335 """
318336 super ().__init__ (reduction = LossReduction (reduction ).value )
319-
337+
320338 # Validate configuration
321339 if use_softmax and num_classes < 2 :
322- raise ValueError (f"use_softmax=True requires num_classes >= 2, but got num_classes={ num_classes } " )
323-
340+ raise ValueError (
341+ f"use_softmax=True requires num_classes >= 2, but got num_classes={ num_classes } "
342+ )
343+
324344 self .to_onehot_y = to_onehot_y
325345 self .num_classes = num_classes
326346 self .gamma = gamma
@@ -362,18 +382,18 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
362382 ValueError: When the number of classes in ground truth exceeds the configured `num_classes`.
363383 ValueError: When use_softmax=True with single-channel input.
364384 """
365-
385+
366386 # Validate input dimensions
367387 if len (y_pred .shape ) not in [4 , 5 ]:
368388 raise ValueError (f"Input shape must be 4 (2D) or 5 (3D), but got { y_pred .shape } " )
369-
389+
370390 # Validate use_softmax with single channel
371391 if self .use_softmax and y_pred .shape [1 ] == 1 :
372392 raise ValueError (
373393 "use_softmax=True is invalid with single-channel input (C=1). "
374394 "Use use_softmax=False for binary segmentation with sigmoid activation."
375395 )
376-
396+
377397 # Shape validation
378398 if y_pred .shape != y_true .shape :
379399 # Allow mismatch only for valid cases
@@ -386,13 +406,14 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
386406 "and this mismatch cannot be resolved by `to_onehot_y` or binary expansion."
387407 )
388408
409+ # Pre-process y_true if needed (will be done inside sub-losses, but validate here)
389410 if self .to_onehot_y and y_true .shape [1 ] == 1 :
390411 # Check indices validity before conversion
391412 if torch .max (y_true ) >= self .num_classes :
392413 raise ValueError (
393414 f"Ground truth contains class indices >= { self .num_classes } , which exceeds num_classes."
394415 )
395-
416+
396417 # Get losses from sub-losses
397418 # AsymmetricFocalLoss with NONE reduction returns (B, C, H, W, [D])
398419 # AsymmetricFocalTverskyLoss with NONE reduction returns (B, C)
@@ -421,4 +442,4 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
421442 else :
422443 loss = torch .mean (combined_loss )
423444
424- return loss
445+ return loss
0 commit comments