
Commit e63e36e

update test_unified_focal_loss.py
Signed-off-by: ytl0623 <[email protected]>
1 parent 1fba9d3 commit e63e36e

2 files changed: +86 −15 lines


monai/losses/unified_focal_loss.py

Lines changed: 10 additions & 5 deletions
@@ -73,7 +73,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
 
         if not self.include_background:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `include_background=False` ignored.")
+                warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
             else:
                 # if skipping background, removing first channel
                 y_true = y_true[:, 1:]
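Not part of the commit, but a quick aside on the only change in this hunk: `stacklevel=2` makes Python report the warning at the caller of this function rather than at the `warnings.warn` line itself, which is what common linters ask for. A minimal illustration:

import warnings

def forward_like_function():
    # With stacklevel=2 the reported file/line point at whoever called this function,
    # making the warning easier to trace back to user code.
    warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)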
@@ -110,8 +110,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
                 loss_list.append(1 - dice_class[:, i])
             else:
                 # Foreground classes: apply focal modulation
-                # Original logic: (1 - dice) * (1 - dice)^(-gamma) -> (1-dice)^(1-gamma)
-                loss_list.append((1 - dice_class[:, i]) * torch.pow(1 - dice_class[:, i], -self.gamma))
+                back_dice = torch.clamp(1 - dice_class[:, i], min=self.epsilon)
+                loss_list.append(back_dice * torch.pow(back_dice, -self.gamma))
 
         loss = torch.stack(loss_list, dim=-1)
 
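Not part of the commit, but a minimal sketch of why the new clamp matters: when a foreground class is predicted perfectly, 1 - dice is 0 and the focal modulation 0 * 0**(-gamma) evaluates to NaN; clamping to self.epsilon keeps the term finite. The gamma and epsilon values below are illustrative, not necessarily MONAI's defaults:

import torch

gamma, epsilon = 0.75, 1e-7  # illustrative values; the loss reads these from self.gamma / self.epsilon
dice = torch.tensor([1.0, 0.8])  # a perfectly predicted foreground class and an imperfect one

# Without the clamp: 0 * 0**(-gamma) = 0 * inf -> nan for the perfect class.
unclamped = (1 - dice) * torch.pow(1 - dice, -gamma)

# With the clamp (as in the patched code): the term stays finite.
back_dice = torch.clamp(1 - dice, min=epsilon)
clamped = back_dice * torch.pow(back_dice, -gamma)

print(unclamped)  # tensor([   nan, 0.6687]) approximately
print(clamped)    # tensor([0.0178, 0.6687]) approximately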
@@ -150,11 +150,13 @@ def __init__(
             to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
             use_softmax: whether to use softmax to transform the original logits into probabilities.
                 If True, softmax is used. If False, sigmoid is used. Defaults to False.
+                softmax for mutually exclusive classes (standard multi-class).
+                sigmoid for multi-label/overlapping classes.
             reduction: Specifies the reduction to apply to the output. Defaults to ``"mean"``.
         """
         super().__init__(reduction=LossReduction(reduction).value)
         self.weight = weight
-        self.use_softmax = use_softmax  # 儲存參數
+        self.use_softmax = use_softmax
 
         self.focal_loss = FocalLoss(
             include_background=include_background,
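A usage sketch of the two activation modes, based only on the constructor parameters visible in this diff and mirroring the shapes used in the new tests (other arguments and defaults may differ):

import torch
from monai.losses import AsymmetricUnifiedFocalLoss

# Mutually exclusive classes (standard multi-class): softmax across channels.
multiclass_loss = AsymmetricUnifiedFocalLoss(use_softmax=True, include_background=True)

# Multi-label / overlapping classes: per-channel sigmoid (use_softmax=False, the default).
multilabel_loss = AsymmetricUnifiedFocalLoss(use_softmax=False)

logits = torch.randn(1, 3, 2, 2)  # (B, N, H, W) raw scores, no activation applied
target = torch.zeros(1, 3, 2, 2)
target[:, 0] = 1.0                # one-hot ground truth
print(multiclass_loss(logits, target))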
@@ -175,8 +177,11 @@ def __init__(
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            y_pred: (BNH[WD]) Logits (raw scores).
+            y_pred: (BNH[WD]) Logits (raw scores, not probabilities).
+                Do not pass pre-activated inputs; activation is applied internally.
             y_true: (BNH[WD]) Ground truth labels.
+        Returns:
+            torch.Tensor: Weighted combination of focal loss and asymmetric focal Tversky loss.
         """
         focal_loss = self.focal_loss(y_pred, y_true)
 
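The new Returns line describes a weighted combination of the two component losses. As a standalone sketch of that idea (the helper name and the 0.5 default are assumptions for illustration, not the exact MONAI source):

import torch

def combine_losses(focal_term: torch.Tensor, tversky_term: torch.Tensor, weight: float = 0.5) -> torch.Tensor:
    # `weight` balances the asymmetric focal term against the asymmetric focal Tversky term.
    return weight * focal_term + (1 - weight) * tversky_term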
tests/losses/test_unified_focal_loss.py

Lines changed: 76 additions & 10 deletions
@@ -19,18 +19,74 @@
 
 from monai.losses import AsymmetricUnifiedFocalLoss
 
+logit_pos = 10.0
+logit_neg = -10.0
+
 TEST_CASES = [
-    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+    [  # Case 0: Binary segmentation
+        # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {
+            "use_softmax": False,
+            "include_background": True,
+        },
         {
-            "y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
-            "y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
+            "y_pred": torch.tensor(
+                [[[[logit_pos, logit_neg], [logit_neg, logit_pos]]], [[[logit_pos, logit_neg], [logit_neg, logit_pos]]]]
+            ),
+            "y_true": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]], [[[1.0, 0.0], [0.0, 1.0]]]]),
         },
         0.0,
     ],
-    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+    [  # Case 1: Multi-class segmentation with softmax
+        # shape: (1, 3, 2, 2), (1, 3, 2, 2)
         {
-            "y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
-            "y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
+            "use_softmax": True,
+            "include_background": True,
+        },
+        {
+            "y_pred": torch.tensor(
+                [
+                    [
+                        [[logit_pos, logit_neg], [logit_neg, logit_neg]],  # Class 0 (background)
+                        [[logit_neg, logit_pos], [logit_neg, logit_neg]],  # Class 1
+                        [[logit_neg, logit_neg], [logit_pos, logit_pos]],  # Class 2
+                    ]
+                ]
+            ),
+            "y_true": torch.tensor(
+                [
+                    [
+                        [[1.0, 0.0], [0.0, 0.0]],  # Class 0 (background)
+                        [[0.0, 1.0], [0.0, 0.0]],  # Class 1
+                        [[0.0, 0.0], [1.0, 1.0]],  # Class 2
+                    ]
+                ]
+            ),
+        },
+        0.0,
+    ],
+    [  # Case 2: Multi-class with background excluded
+        # shape: (1, 3, 2, 2), (1, 3, 2, 2)
+        {"use_softmax": True, "include_background": False},
+        {
+            "y_pred": torch.tensor(
+                [
+                    [
+                        [[logit_pos, logit_neg], [logit_neg, logit_neg]],  # Class 0 (background)
+                        [[logit_neg, logit_pos], [logit_pos, logit_neg]],  # Class 1 (foreground)
+                        [[logit_neg, logit_neg], [logit_neg, logit_pos]],  # Class 2 (foreground)
+                    ]
+                ]
+            ),
+            "y_true": torch.tensor(
+                [
+                    [
+                        [[1.0, 0.0], [0.0, 0.0]],  # Class 0 (background)
+                        [[0.0, 1.0], [1.0, 0.0]],  # Class 1 (foreground)
+                        [[0.0, 0.0], [0.0, 1.0]],  # Class 2 (foreground)
+                    ]
+                ]
+            ),
         },
         0.0,
     ],
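A quick standalone check (not part of the test file) of why logits of ±10 are used: they saturate the activations to essentially 0 or 1, so predictions that match the one-hot targets exactly drive the loss toward the expected value of 0.0:

import torch

logit_pos, logit_neg = 10.0, -10.0
# Sigmoid path (binary / multi-label cases): outputs are ~1 and ~0.
print(torch.sigmoid(torch.tensor([logit_pos, logit_neg])))
# Softmax path (multi-class cases): the positive-logit channel takes ~all the probability mass.
print(torch.softmax(torch.tensor([logit_pos, logit_neg, logit_neg]), dim=0))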
@@ -40,8 +96,16 @@
 class TestAsymmetricUnifiedFocalLoss(unittest.TestCase):
 
     @parameterized.expand(TEST_CASES)
-    def test_result(self, input_data, expected_val):
-        loss = AsymmetricUnifiedFocalLoss()
+    def test_result(self, input_param, input_data, expected_val):
+        """
+        Test AsymmetricUnifiedFocalLoss with various configurations.
+
+        Args:
+            input_param: Dict of loss constructor parameters (use_softmax, include_background, etc.).
+            input_data: Dict containing y_pred (logits) and y_true (ground truth) tensors.
+            expected_val: Expected loss value.
+        """
+        loss = AsymmetricUnifiedFocalLoss(**input_param)
         result = loss(**input_data)
         np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
 
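For readers unfamiliar with `parameterized.expand`: each TEST_CASES entry is unpacked positionally into the test method, which is why adding the config dict as the first element of every case requires the new `input_param` argument. A minimal, self-contained illustration (the names below are hypothetical, not from the repository):

import unittest
from parameterized import parameterized

CASES = [
    ({"use_softmax": False}, {"value": 1}, 0.0),  # (input_param, input_data, expected_val)
]

class DemoUnpacking(unittest.TestCase):
    @parameterized.expand(CASES)
    def test_unpacking(self, input_param, input_data, expected_val):
        # Each case is spread across the positional arguments after `self`.
        self.assertIsInstance(input_param, dict)
        self.assertEqual(expected_val, 0.0)

if __name__ == "__main__":
    unittest.main()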
@@ -52,8 +116,10 @@ def test_ill_shape(self):
 
     def test_with_cuda(self):
         loss = AsymmetricUnifiedFocalLoss()
-        i = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])
-        j = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])
+        i = torch.tensor(
+            [[[[logit_pos, logit_neg], [logit_neg, logit_pos]]], [[[logit_pos, logit_neg], [logit_neg, logit_pos]]]]
+        )
+        j = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]], [[[1.0, 0.0], [0.0, 1.0]]]])
         if torch.cuda.is_available():
             i = i.cuda()
             j = j.cuda()
