Commit 50dc037

another test to cover the case without the background

Signed-off-by: ytl0623 <[email protected]>

1 parent 11a4f46 · commit 50dc037

File tree

2 files changed: +37 -35 lines

monai/losses/focal_loss.py (3 additions, 1 deletion)
@@ -168,7 +168,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         if self.use_softmax:
             if not self.include_background and self.alpha is not None:
                 if isinstance(self.alpha, (float, int)):
-                    warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
+                    warnings.warn(
+                        "`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2
+                    )
             loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
         else:
             loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
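
For context, a minimal, hedged sketch of the path this hunk touches: with use_softmax=True, include_background=False, and a scalar alpha, FocalLoss emits the warning shown above, and stacklevel=2 makes the warning report the caller's location rather than a line inside focal_loss.py. The shapes and the alpha value below are illustrative, not taken from the commit.

import warnings

import torch

from monai.losses import FocalLoss

# Scalar alpha + softmax + include_background=False: the warning path changed above.
loss_fn = FocalLoss(to_onehot_y=True, include_background=False, use_softmax=True, alpha=0.8)
logits = torch.randn(2, 3, 4, 4)  # illustrative: batch=2, 3 classes, 4x4 spatial
target = torch.randint(0, 3, (2, 1, 4, 4))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    loss = loss_fn(logits, target)

# With stacklevel=2, the reported file/line is this call site, not the loss internals.
print(float(loss), [str(w.message) for w in caught])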

tests/losses/test_focal_loss.py (34 additions, 34 deletions)
@@ -21,7 +21,7 @@
 
 from monai.losses import FocalLoss
 from monai.networks import one_hot
-from tests.test_utils import test_script_save
+from tests.test_utils import test_script_save, TEST_DEVICES
 
 TEST_CASES = []
 for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]:
@@ -77,6 +77,13 @@
     TEST_CASES.append([{"to_onehot_y": True, "use_softmax": True}, input_data, 0.16276])
     TEST_CASES.append([{"to_onehot_y": True, "alpha": 0.8, "use_softmax": True}, input_data, 0.08138])
 
+TEST_ALPHA_BROADCASTING = []
+for case in TEST_DEVICES:
+    device = case[0]
+    for include_background in [True, False]:
+        for use_softmax in [True, False]:
+            TEST_ALPHA_BROADCASTING.append([device, include_background, use_softmax])
+
 
 class TestFocalLoss(unittest.TestCase):
     @parameterized.expand(TEST_CASES)
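
For readers unfamiliar with the helper, a hedged sketch of what this parameterization expands to. The shape of TEST_DEVICES here is an assumption (a list of one-element lists of torch devices, matching the case[0] access above); the real definition lives in tests/test_utils.

import torch

# Assumed stand-in for tests.test_utils.TEST_DEVICES (hypothetical values):
TEST_DEVICES = [[torch.device("cpu")]]
if torch.cuda.is_available():
    TEST_DEVICES.append([torch.device("cuda:0")])

# Same construction as the hunk above: four cases per device.
TEST_ALPHA_BROADCASTING = []
for case in TEST_DEVICES:
    device = case[0]
    for include_background in [True, False]:
        for use_softmax in [True, False]:
            TEST_ALPHA_BROADCASTING.append([device, include_background, use_softmax])

# @parameterized.expand turns each entry into its own test method, each receiving
# (device, include_background, use_softmax) as positional arguments.
print(len(TEST_ALPHA_BROADCASTING))  # 4 on CPU-only machines, 8 with CUDA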
@@ -374,46 +381,39 @@ def test_script(self):
         test_input = torch.ones(2, 2, 8, 8)
         test_script_save(loss, test_input, test_input)
 
-    def test_alpha_sequence_broadcasting(self):
+    @parameterized.expand(TEST_ALPHA_BROADCASTING)
+    def test_alpha_sequence_broadcasting(self, device, include_background, use_softmax):
         """
         Test FocalLoss with alpha as a sequence for proper broadcasting.
         """
         num_classes = 3
-        alpha_seq = [0.1, 0.5, 2.0]
         batch_size = 2
         spatial_dims = (4, 4)
 
-        devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
-
-        for device in devices:
-            logits = torch.randn(batch_size, num_classes, *spatial_dims, device=device)
-            target = torch.randint(0, num_classes, (batch_size, 1, *spatial_dims), device=device)
-
-            # Case 1: Softmax + Alpha Sequence
-            loss_func_softmax = FocalLoss(
-                to_onehot_y=True, gamma=2.0, alpha=alpha_seq, use_softmax=True, reduction="mean"
-            )
-            loss_soft = loss_func_softmax(logits, target)
-
-            self.assertTrue(torch.is_tensor(loss_soft))
-            self.assertEqual(loss_soft.ndim, 0)
-            self.assertTrue(loss_soft > 0, f"Softmax loss on {device} should be positive")
-
-            # Case 2: Sigmoid + Alpha Sequence
-            loss_func_sigmoid = FocalLoss(
-                to_onehot_y=True, gamma=2.0, alpha=alpha_seq, use_softmax=False, reduction="mean"
-            )
-            loss_sig = loss_func_sigmoid(logits, target)
-
-            self.assertTrue(torch.is_tensor(loss_sig))
-            self.assertEqual(loss_sig.ndim, 0)
-            self.assertTrue(loss_sig > 0, f"Sigmoid loss on {device} should be positive")
-
-            # Case 3: Error Handling (Mismatched alpha length)
-            if device == devices[0]:
-                wrong_alpha = [0.1, 0.5]
-                with self.assertRaisesRegex(ValueError, "length of alpha"):
-                    FocalLoss(to_onehot_y=True, alpha=wrong_alpha, use_softmax=True)(logits, target)
+        logits = torch.randn(batch_size, num_classes, *spatial_dims, device=device)
+        target = torch.randint(0, num_classes, (batch_size, 1, *spatial_dims), device=device)
+
+        if include_background:
+            alpha_seq = [0.1, 0.5, 2.0]
+        else:
+            alpha_seq = [0.5, 2.0]
+
+        loss_func = FocalLoss(
+            to_onehot_y=True,
+            gamma=2.0,
+            alpha=alpha_seq,
+            include_background=include_background,
+            use_softmax=use_softmax,
+            reduction="mean",
+        )
+
+        result = loss_func(logits, target)
+
+        self.assertTrue(torch.is_tensor(result))
+        self.assertEqual(result.ndim, 0)
+        self.assertTrue(
+            result > 0, f"Loss should be positive. params: dev={device}, bg={include_background}, softmax={use_softmax}"
+        )
 
 
 if __name__ == "__main__":
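
The reason alpha_seq shrinks to two entries when include_background=False: dropping the background leaves num_classes - 1 channels, and a per-class alpha must match the channel count FocalLoss actually sees. A hedged standalone sketch using the same values as the test; sequence-valued alpha is the behaviour this branch adds, so the snippet assumes the PR's FocalLoss, and the shapes are illustrative.

import torch

from monai.losses import FocalLoss

num_classes = 3
logits = torch.randn(2, num_classes, 4, 4)
target = torch.randint(0, num_classes, (2, 1, 4, 4))

# With the background kept, alpha carries one weight per class channel (3 entries).
with_bg = FocalLoss(to_onehot_y=True, alpha=[0.1, 0.5, 2.0], use_softmax=True)

# With the background dropped, only 2 channels remain, so alpha needs 2 entries.
without_bg = FocalLoss(to_onehot_y=True, include_background=False, alpha=[0.5, 2.0], use_softmax=True)

print(float(with_bg(logits, target)), float(without_bg(logits, target)))

A mismatched sequence length (for example, a 2-entry alpha against 3 channels) is the error case the earlier, unparameterized version of this test checked via assertRaisesRegex(ValueError, "length of alpha").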
