Commit 015a894

add test case for alpha as a sequence

Signed-off-by: ytl0623 <[email protected]>

1 parent 5043ac9 commit 015a894

File tree

2 files changed: +42 -3 lines changed


monai/losses/focal_loss.py

Lines changed: 1 addition & 3 deletions

@@ -165,9 +165,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         input = input.float()
         target = target.float()
 
-        alpha_arg: float | torch.Tensor | None = self.alpha
-        if isinstance(alpha_arg, torch.Tensor):
-            alpha_arg = alpha_arg.to(input.device)
+        alpha_arg = self.alpha
 
         if self.use_softmax:
             if not self.include_background and self.alpha is not None:
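
For context, a minimal usage sketch of the sequence-alpha path this change feeds into, mirroring the shapes and arguments of the new test below (illustration only; this snippet is not part of the commit):

import torch
from monai.losses import FocalLoss

# One alpha weight per class channel, as exercised by the new test.
loss_fn = FocalLoss(to_onehot_y=True, gamma=2.0, alpha=[0.1, 0.5, 2.0], use_softmax=True, reduction="mean")
logits = torch.randn(2, 3, 4, 4)            # (batch, num_classes, H, W)
target = torch.randint(0, 3, (2, 1, 4, 4))  # integer class labels in one channel
loss = loss_fn(logits, target)              # scalar tensor with reduction="mean"
print(float(loss))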

tests/losses/test_focal_loss.py

Lines changed: 41 additions & 0 deletions

@@ -374,6 +374,47 @@ def test_script(self):
         test_input = torch.ones(2, 2, 8, 8)
         test_script_save(loss, test_input, test_input)
 
+    def test_alpha_sequence_broadcasting(self):
+        """
+        Test FocalLoss with alpha as a sequence for proper broadcasting.
+        """
+        num_classes = 3
+        alpha_seq = [0.1, 0.5, 2.0]
+        batch_size = 2
+        spatial_dims = (4, 4)
+
+        devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+
+        for device in devices:
+            logits = torch.randn(batch_size, num_classes, *spatial_dims, device=device)
+            target = torch.randint(0, num_classes, (batch_size, 1, *spatial_dims), device=device)
+
+            # Case 1: Softmax + Alpha Sequence
+            loss_func_softmax = FocalLoss(
+                to_onehot_y=True, gamma=2.0, alpha=alpha_seq, use_softmax=True, reduction="mean"
+            )
+            loss_soft = loss_func_softmax(logits, target)
+
+            self.assertTrue(torch.is_tensor(loss_soft))
+            self.assertEqual(loss_soft.ndim, 0)
+            self.assertTrue(loss_soft > 0, f"Softmax loss on {device} should be positive")
+
+            # Case 2: Sigmoid + Alpha Sequence
+            loss_func_sigmoid = FocalLoss(
+                to_onehot_y=True, gamma=2.0, alpha=alpha_seq, use_softmax=False, reduction="mean"
+            )
+            loss_sig = loss_func_sigmoid(logits, target)
+
+            self.assertTrue(torch.is_tensor(loss_sig))
+            self.assertEqual(loss_sig.ndim, 0)
+            self.assertTrue(loss_sig > 0, f"Sigmoid loss on {device} should be positive")
+
+            # Case 3: Error Handling (Mismatched alpha length)
+            if device == devices[0]:
+                wrong_alpha = [0.1, 0.5]
+                with self.assertRaisesRegex(ValueError, "length of alpha"):
+                    FocalLoss(to_onehot_y=True, alpha=wrong_alpha, use_softmax=True)(logits, target)
+
 
 if __name__ == "__main__":
     unittest.main()
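
On the "broadcasting" the test name refers to: a minimal sketch of how a per-class alpha sequence can broadcast against a (B, C, *spatial) focal term. The names here are hypothetical and this is not the code path inside monai.losses.FocalLoss, only the shape arithmetic the test exercises:

import torch

focal_term = torch.rand(2, 3, 4, 4)        # stand-in for a per-voxel focal term, shape (B, C, H, W)
alpha = torch.as_tensor([0.1, 0.5, 2.0])   # one weight per class, matching C=3
alpha = alpha.view(1, -1, *([1] * (focal_term.ndim - 2)))  # reshape to (1, 3, 1, 1)
weighted = alpha * focal_term              # broadcasts over batch and spatial dims
print(weighted.shape)                      # torch.Size([2, 3, 4, 4])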
