
Commit a89727d

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent ecf33eb commit a89727d

File tree: 1 file changed (+16, -16)

monai/losses/unified_focal_loss.py: 16 additions & 16 deletions

All 16 changed lines differ only in whitespace (blank lines, plus the final `return loss` line), consistent with the pre-commit auto-fix commit message above.
@@ -71,26 +71,26 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims).
             y_true: ground truth labels. Shape should match y_pred.
         """
-
+
         # Track whether we've already converted to probabilities
         is_converted_to_prob = self.is_prob_input
-
+
         # Handle single-channel binary case
         if y_pred.shape[1] == 1:
             # For binary segmentation, enforce num_classes == 2
             if self.num_classes != 2:
                 raise ValueError(
                     f"Single-channel input requires num_classes=2, but got num_classes={self.num_classes}"
                 )
-
+
             # Convert to probability if needed
             if not is_converted_to_prob:
                 y_pred = torch.sigmoid(y_pred)
                 is_converted_to_prob = True
-
+
             # Expand to 2 channels
             y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
-
+
             # Convert y_true to one-hot with 2 classes (matching expanded y_pred)
             if y_true.shape[1] == 1:
                 y_true = one_hot(y_true, num_classes=2)
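
The single-channel branch above can be exercised in isolation. A minimal sketch, assuming standalone tensors and the one_hot helper from monai.networks; the tensor names and shapes are illustrative, not taken from the file:

import torch
from monai.networks import one_hot

y_pred = torch.randn(2, 1, 32, 32)               # (B, C=1, H, W) logits
y_true = torch.randint(0, 2, (2, 1, 32, 32))     # (B, 1, H, W) binary labels

y_pred = torch.sigmoid(y_pred)                   # logits -> foreground probabilities
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)  # expand to 2 channels: (B, 2, H, W)
y_true = one_hot(y_true, num_classes=2)          # one-hot target matching the expanded prediction

print(y_pred.shape, y_true.shape)                # torch.Size([2, 2, 32, 32]) for both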
@@ -213,23 +213,23 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
 
         # Track whether we've already converted to probabilities
         is_converted_to_prob = self.is_prob_input
-
+
         # Handle single-channel binary case
         if y_pred.shape[1] == 1:
             # For binary segmentation, enforce num_classes == 2
             if self.num_classes != 2:
                 raise ValueError(
                     f"Single-channel input requires num_classes=2, but got num_classes={self.num_classes}"
                 )
-
+
             # Convert to probability if needed
             if not is_converted_to_prob:
                 y_pred = torch.sigmoid(y_pred)
                 is_converted_to_prob = True
-
+
             # Expand to 2 channels
             y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
-
+
             # Convert y_true to one-hot with 2 classes (matching expanded y_pred)
             if y_true.shape[1] == 1:
                 y_true = one_hot(y_true, num_classes=2)
@@ -334,13 +334,13 @@ def __init__(
             >>> fl(pred, grnd)
         """
         super().__init__(reduction=LossReduction(reduction).value)
-
+
         # Validate configuration
         if use_softmax and num_classes < 2:
             raise ValueError(
                 f"use_softmax=True requires num_classes >= 2, but got num_classes={num_classes}"
             )
-
+
         self.to_onehot_y = to_onehot_y
         self.num_classes = num_classes
         self.gamma = gamma
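
For context, a hedged usage sketch of the configuration check above. It assumes this __init__ belongs to AsymmetricUnifiedFocalLoss in this module and that the use_softmax flag shown in the diff is part of its signature; the forward call at the end mirrors the docstring's `fl(pred, grnd)` example and depends on code not shown in this diff.

import torch
from monai.losses.unified_focal_loss import AsymmetricUnifiedFocalLoss  # class name assumed, not shown in this diff

# Accepted: softmax over >= 2 classes.
fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True, num_classes=3, use_softmax=True)

# Rejected by the check above (use_softmax=True with num_classes < 2):
# AsymmetricUnifiedFocalLoss(to_onehot_y=True, num_classes=1, use_softmax=True)

pred = torch.randn(2, 3, 64, 64)             # (B, C, H, W) logits, 2D case
grnd = torch.randint(0, 3, (2, 1, 64, 64))   # class-index labels, single channel
loss = fl(pred, grnd)                        # mirrors the docstring example `fl(pred, grnd)`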
@@ -382,18 +382,18 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             ValueError: When the number of classes in ground truth exceeds the configured `num_classes`.
             ValueError: When use_softmax=True with single-channel input.
         """
-
+
         # Validate input dimensions
         if len(y_pred.shape) not in [4, 5]:
             raise ValueError(f"Input shape must be 4 (2D) or 5 (3D), but got {y_pred.shape}")
-
+
         # Validate use_softmax with single channel
         if self.use_softmax and y_pred.shape[1] == 1:
             raise ValueError(
                 "use_softmax=True is invalid with single-channel input (C=1). "
                 "Use use_softmax=False for binary segmentation with sigmoid activation."
             )
-
+
         # Shape validation
         if y_pred.shape != y_true.shape:
             # Allow mismatch only for valid cases
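
A small sketch of what the dimensionality check above accepts; the tensors are illustrative only:

import torch

y_pred_2d = torch.randn(2, 2, 64, 64)        # len(shape) == 4 -> 2D case, accepted
y_pred_3d = torch.randn(2, 2, 32, 32, 32)    # len(shape) == 5 -> 3D case, accepted
y_pred_bad = torch.randn(2, 2, 64)           # len(shape) == 3 -> would raise the ValueError above

for t in (y_pred_2d, y_pred_3d, y_pred_bad):
    print(len(t.shape), len(t.shape) in [4, 5])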
@@ -413,7 +413,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"Ground truth contains class indices >= {self.num_classes}, which exceeds num_classes."
                 )
-
+
         # Get losses from sub-losses
         # AsymmetricFocalLoss with NONE reduction returns (B, C, H, W, [D])
         # AsymmetricFocalTverskyLoss with NONE reduction returns (B, C)
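
The two comments above describe different NONE-reduction shapes. The combination step itself is not shown in this diff; purely as an illustration, one way to reconcile a (B, C, H, W) focal map with a (B, C) Tversky term before the final reduction is:

import torch

focal = torch.rand(2, 2, 8, 8)             # (B, C, H, W): element-wise focal term
tversky = torch.rand(2, 2)                 # (B, C): per-class Tversky term

focal_per_class = focal.mean(dim=(2, 3))   # reduce spatial dims -> (B, C)
weight = 0.5                               # illustrative weighting, not taken from the file
combined_loss = weight * focal_per_class + (1 - weight) * tversky
loss = torch.mean(combined_loss)           # matches the final mean reduction in the hunk below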
@@ -442,4 +442,4 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         else:
             loss = torch.mean(combined_loss)
 
-        return loss
+        return loss
