
Commit 8368ef2

fix bugs
Signed-off-by: ytl0623 <[email protected]>
1 parent 2066793 · commit 8368ef2

File tree

2 files changed: +51 -44 lines

monai/losses/unified_focal_loss.py

Lines changed: 38 additions & 25 deletions
@@ -66,12 +66,11 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             y_true: ground truth labels. Shape should match y_pred.
         """
 
-        # Auto-handle single channel input (binary segmentation case)
-        if y_pred.shape[1] == 1 and not self.use_softmax:
+        if y_pred.shape[1] == 1:
             y_pred = torch.sigmoid(y_pred)
             y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
             is_already_prob = True
-            # Expand y_true to match if it's single channel
+
             if y_true.shape[1] == 1:
                 y_true = one_hot(y_true, num_classes=2)
         else:
@@ -122,12 +121,12 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         # Apply reduction
         if self.reduction == LossReduction.MEAN.value:
             return torch.mean(all_losses)
-        if self.reduction == LossReduction.SUM.value:
+        elif self.reduction == LossReduction.SUM.value:
            return torch.sum(all_losses)
-        if self.reduction == LossReduction.NONE.value:
+        elif self.reduction == LossReduction.NONE.value:
            return all_losses
-
-        return torch.mean(all_losses)
+        else:
+            return torch.mean(all_losses)
 
 
 class AsymmetricFocalLoss(_Loss):
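
Note: the heart of the first hunk is the single-channel expansion at the top of forward: a (B, 1, ...) logit map is passed through sigmoid and concatenated into a two-channel probability map. A minimal plain-PyTorch sketch of that expansion (illustrative only, not part of the commit):

import torch

y_pred = torch.randn(2, 1, 4, 4)                 # single-channel logits
probs = torch.sigmoid(y_pred)                    # foreground probabilities
expanded = torch.cat([1 - probs, probs], dim=1)  # [background, foreground]
print(expanded.shape)                            # torch.Size([2, 2, 4, 4])
assert torch.allclose(expanded.sum(dim=1), torch.ones(2, 4, 4))

The two channels sum to one at every position, which is presumably why the patched code can set is_already_prob and avoid a second activation downstream.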
@@ -253,6 +252,14 @@ def __init__(
             delta: background/foreground balancing weight. Defaults to 0.7.
             reduction: specifies the reduction to apply to the output. Defaults to "mean".
             use_softmax: whether to use softmax for probability conversion. Defaults to False.
+
+        Example:
+            >>> import torch
+            >>> from monai.losses import AsymmetricUnifiedFocalLoss
+            >>> pred = torch.ones((1,1,32,32), dtype=torch.float32)
+            >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
+            >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
+            >>> fl(pred, grnd)
         """
         super().__init__(reduction=LossReduction(reduction).value)
         self.to_onehot_y = to_onehot_y
@@ -283,30 +290,36 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             y_pred: Prediction logits. Shape: (B, C, H, W, [D]).
                 Supports binary (C=1 or C=2) and multi-class (C>2) segmentation.
             y_true: Ground truth labels. Shape should match y_pred (or be indices if to_onehot_y is True).
+
+        Raises:
+            ValueError: When ground truth shape does not match input shape.
+            ValueError: When input tensor shape is not 4D or 5D.
+            ValueError: When the number of classes in ground truth exceeds the configured `num_classes`.
         """
         if y_pred.shape != y_true.shape:
-            is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
-            if not self.to_onehot_y and not is_binary_logits:
-                raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
-            raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
+            is_binary_logits = (y_pred.shape[1] == 1) and (not self.use_softmax)
+            is_target_needs_onehot = self.to_onehot_y and (y_true.shape[1] == 1)
 
-        if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
-            raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")
-
-        if y_pred.shape[1] == 1:
-            y_pred = one_hot(y_pred, num_classes=self.num_classes)
-            y_true = one_hot(y_true, num_classes=self.num_classes)
+            if not is_binary_logits and not is_target_needs_onehot:
+                raise ValueError(
+                    f"Ground truth has different shape ({y_true.shape}) from input ({y_pred.shape}), "
+                    "and this mismatch cannot be resolved by `to_onehot_y` or binary expansion."
+                )
 
-        if torch.max(y_true) != self.num_classes - 1:
-            raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")
+        if len(y_pred.shape) not in [4, 5]:
+            raise ValueError(f"Input shape must be 4 (2D) or 5 (3D), but got {y_pred.shape}")
 
-        n_pred_ch = y_pred.shape[1]
         if self.to_onehot_y:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
-
+            # Only convert if y_true is single channel (Indices)
+            if y_true.shape[1] == 1:
+                # Check indices validity before conversion
+                if torch.max(y_true) >= self.num_classes:
+                    raise ValueError(
+                        f"Ground truth contains class indices >= {self.num_classes}, which exceeds num_classes."
+                    )
+
+                # Convert to One-hot
+                y_true = one_hot(y_true, num_classes=self.num_classes)
         asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
         asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
 
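To see the new validation path end to end, here is a short sketch (it assumes a MONAI build that includes this patch; the shapes are illustrative):

import torch
from monai.losses import AsymmetricUnifiedFocalLoss

loss = AsymmetricUnifiedFocalLoss(to_onehot_y=True, num_classes=2)
pred = torch.randn(1, 2, 8, 8)                             # (B, C, H, W) logits
bad_grnd = torch.full((1, 1, 8, 8), 5, dtype=torch.int64)  # index 5 >= num_classes

try:
    loss(pred, bad_grnd)
except ValueError as err:
    print(err)  # Ground truth contains class indices >= 2, which exceeds num_classes.

The shape mismatch between pred and bad_grnd is tolerated because to_onehot_y applies, while the out-of-range index now fails fast with a descriptive message instead of the old "Please make sure the number of classes is ..." error.
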
tests/losses/test_unified_focal_loss.py

Lines changed: 13 additions & 19 deletions
@@ -20,24 +20,20 @@
 from monai.losses import AsymmetricUnifiedFocalLoss
 
 # 1. Binary Case (Logits input): Prediction matches GT perfectly
-# Input Shape: (B, 1, H, W) -> Auto expanded internally
 TEST_CASE_BINARY_LOGITS = [
     {"y_pred": torch.tensor([[[[10.0, -10.0], [-10.0, 10.0]]]]), "y_true": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])},
     0.0,
-    {"use_softmax": False, "to_onehot_y": False},
+    {"use_softmax": False, "to_onehot_y": False, "num_classes": 2},
 ]
 
 # 2. Binary Case (2 Channels input): Prediction matches GT perfectly
-# Input Shape: (B, 2, H, W)
 TEST_CASE_BINARY_2CH = [
     {
-        "y_pred": torch.tensor(
-            [[[[-10.0, 10.0], [10.0, -10.0]], [[10.0, -10.0], [-10.0, 10.0]]]]  # Ch0 (Background): Low, High, High, Low
-        ),  # Ch1 (Foreground): High, Low, Low, High
+        "y_pred": torch.tensor([[[[-10.0, 10.0], [10.0, -10.0]], [[10.0, -10.0], [-10.0, 10.0]]]]),
         "y_true": torch.tensor([[[[1, 0], [0, 1]]]]),
     },
     0.0,
-    {"use_softmax": True, "to_onehot_y": True},
+    {"use_softmax": True, "to_onehot_y": True, "num_classes": 2},
 ]
 
 # 3. Multi-Class Case (3 Channels): Prediction matches GT perfectly
@@ -46,16 +42,16 @@
         "y_pred": torch.tensor(
             [
                 [
-                    [[10.0, -10.0], [-10.0, 10.0]],  # Class 0 Logits
-                    [[-10.0, 10.0], [-10.0, -10.0]],  # Class 1 Logits
-                    [[-10.0, -10.0], [10.0, -10.0]],
+                    [[10.0, -10.0], [-10.0, 10.0]],  # Class 0
+                    [[-10.0, 10.0], [-10.0, -10.0]],  # Class 1
+                    [[-10.0, -10.0], [10.0, -10.0]],  # Class 2
                 ]
             ]
-        ),  # Class 2 Logits
-        "y_true": torch.tensor([[[[0, 1], [2, 0]]]]),  # Indices
+        ),
+        "y_true": torch.tensor([[[[0, 1], [2, 0]]]]),
     },
     0.0,
-    {"use_softmax": True, "to_onehot_y": True},
+    {"use_softmax": True, "to_onehot_y": True, "num_classes": 3},
 ]
 
 # 4. Multi-Class Case: Wrong Prediction
@@ -64,10 +60,10 @@
         "y_pred": torch.tensor(
             [[[[-10.0, -10.0], [-10.0, -10.0]], [[10.0, 10.0], [10.0, 10.0]], [[-10.0, -10.0], [-10.0, -10.0]]]]
         ),
-        "y_true": torch.tensor([[[[0, 0], [0, 0]]]]),  # GT is class 0, but Pred is class 1
+        "y_true": torch.tensor([[[[0, 0], [0, 0]]]]),
     },
     None,
-    {"use_softmax": True, "to_onehot_y": True},
+    {"use_softmax": True, "to_onehot_y": True, "num_classes": 3},
 ]
 
 
@@ -77,11 +73,11 @@ class TestAsymmetricUnifiedFocalLoss(unittest.TestCase):
     def test_perfect_prediction(self, input_data, expected_val, args):
         loss_func = AsymmetricUnifiedFocalLoss(**args)
         result = loss_func(**input_data)
-        # We use a small tolerance because 10.0 logits is not exactly probability 1.0
+        # Using a relaxed tolerance for logits -> probability conversion
         np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-3, rtol=1e-3)
 
     @parameterized.expand([TEST_CASE_MULTICLASS_WRONG])
-    def test_wrong_prediction(self, input_data, expected_val, args):
+    def test_wrong_prediction(self, input_data, _, args):
         loss_func = AsymmetricUnifiedFocalLoss(**args)
         result = loss_func(**input_data)
         self.assertGreater(result.item(), 1.0, "Loss should be high for wrong predictions")
@@ -93,7 +89,6 @@ def test_ill_shape(self):
 
     def test_with_cuda(self):
         if not torch.cuda.is_available():
-            print("CUDA not available, skipping test_with_cuda")
             return
 
         loss = AsymmetricUnifiedFocalLoss(use_softmax=False, to_onehot_y=False)
@@ -102,7 +97,6 @@ def test_with_cuda(self):
         j = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]).cuda()
 
         output = loss(i, j)
-        print(f"CUDA Output: {output.item()}")
         self.assertTrue(output.is_cuda)
         self.assertLess(output.item(), 1.0)
 
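For a quick manual check outside the unittest harness, the first test case (TEST_CASE_BINARY_LOGITS) can be reproduced directly, again assuming a build that includes this patch:

import torch
from monai.losses import AsymmetricUnifiedFocalLoss

loss = AsymmetricUnifiedFocalLoss(use_softmax=False, to_onehot_y=False, num_classes=2)
y_pred = torch.tensor([[[[10.0, -10.0], [-10.0, 10.0]]]])  # (B, 1, H, W) logits
y_true = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])        # matching binary mask

print(loss(y_pred, y_true).item())  # expected ~0.0, within the test's 1e-3 tolerance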