@@ -71,26 +71,26 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7171 y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims).
7272 y_true: ground truth labels. Shape should match y_pred.
7373 """
74-
74+
7575 # Track whether we've already converted to probabilities
7676 is_converted_to_prob = self .is_prob_input
77-
77+
7878 # Handle single-channel binary case
7979 if y_pred .shape [1 ] == 1 :
8080 # For binary segmentation, enforce num_classes == 2
8181 if self .num_classes != 2 :
8282 raise ValueError (
8383 f"Single-channel input requires num_classes=2, but got num_classes={ self .num_classes } "
8484 )
85-
85+
8686 # Convert to probability if needed
8787 if not is_converted_to_prob :
8888 y_pred = torch .sigmoid (y_pred )
8989 is_converted_to_prob = True
90-
90+
9191 # Expand to 2 channels
9292 y_pred = torch .cat ([1 - y_pred , y_pred ], dim = 1 )
93-
93+
9494 # Convert y_true to one-hot with 2 classes (matching expanded y_pred)
9595 if y_true .shape [1 ] == 1 :
9696 y_true = one_hot (y_true , num_classes = 2 )
@@ -213,23 +213,23 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
213213
214214 # Track whether we've already converted to probabilities
215215 is_converted_to_prob = self .is_prob_input
216-
216+
217217 # Handle single-channel binary case
218218 if y_pred .shape [1 ] == 1 :
219219 # For binary segmentation, enforce num_classes == 2
220220 if self .num_classes != 2 :
221221 raise ValueError (
222222 f"Single-channel input requires num_classes=2, but got num_classes={ self .num_classes } "
223223 )
224-
224+
225225 # Convert to probability if needed
226226 if not is_converted_to_prob :
227227 y_pred = torch .sigmoid (y_pred )
228228 is_converted_to_prob = True
229-
229+
230230 # Expand to 2 channels
231231 y_pred = torch .cat ([1 - y_pred , y_pred ], dim = 1 )
232-
232+
233233 # Convert y_true to one-hot with 2 classes (matching expanded y_pred)
234234 if y_true .shape [1 ] == 1 :
235235 y_true = one_hot (y_true , num_classes = 2 )
@@ -334,13 +334,13 @@ def __init__(
334334 >>> fl(pred, grnd)
335335 """
336336 super ().__init__ (reduction = LossReduction (reduction ).value )
337-
337+
338338 # Validate configuration
339339 if use_softmax and num_classes < 2 :
340340 raise ValueError (
341341 f"use_softmax=True requires num_classes >= 2, but got num_classes={ num_classes } "
342342 )
343-
343+
344344 self .to_onehot_y = to_onehot_y
345345 self .num_classes = num_classes
346346 self .gamma = gamma
@@ -382,18 +382,18 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
382382 ValueError: When the number of classes in ground truth exceeds the configured `num_classes`.
383383 ValueError: When use_softmax=True with single-channel input.
384384 """
385-
385+
386386 # Validate input dimensions
387387 if len (y_pred .shape ) not in [4 , 5 ]:
388388 raise ValueError (f"Input shape must be 4 (2D) or 5 (3D), but got { y_pred .shape } " )
389-
389+
390390 # Validate use_softmax with single channel
391391 if self .use_softmax and y_pred .shape [1 ] == 1 :
392392 raise ValueError (
393393 "use_softmax=True is invalid with single-channel input (C=1). "
394394 "Use use_softmax=False for binary segmentation with sigmoid activation."
395395 )
396-
396+
397397 # Shape validation
398398 if y_pred .shape != y_true .shape :
399399 # Allow mismatch only for valid cases
@@ -413,7 +413,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
413413 raise ValueError (
414414 f"Ground truth contains class indices >= { self .num_classes } , which exceeds num_classes."
415415 )
416-
416+
417417 # Get losses from sub-losses
418418 # AsymmetricFocalLoss with NONE reduction returns (B, C, H, W, [D])
419419 # AsymmetricFocalTverskyLoss with NONE reduction returns (B, C)
@@ -442,4 +442,4 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
442442 else :
443443 loss = torch .mean (combined_loss )
444444
445- return loss
445+ return loss
0 commit comments