
Commit 1f37d0d

add test case for alpha as a sequence
Signed-off-by: ytl0623 <[email protected]>
1 parent 5043ac9 commit 1f37d0d


2 files changed: 43 additions, 7 deletions


monai/losses/focal_loss.py

Lines changed: 2 additions & 7 deletions
@@ -165,18 +165,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         input = input.float()
         target = target.float()
 
-        alpha_arg: float | torch.Tensor | None = self.alpha
-        if isinstance(alpha_arg, torch.Tensor):
-            alpha_arg = alpha_arg.to(input.device)
-
         if self.use_softmax:
             if not self.include_background and self.alpha is not None:
                 if isinstance(self.alpha, (float, int)):
-                    alpha_arg = None
                     warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
-            loss = softmax_focal_loss(input, target, self.gamma, alpha_arg)
+            loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
         else:
-            loss = sigmoid_focal_loss(input, target, self.gamma, alpha_arg)
+            loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
 
         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:
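
With the intermediate alpha_arg removed, the alpha value stored on the loss module (scalar or per-class sequence) is passed straight through to softmax_focal_loss / sigmoid_focal_loss. The sketch below mirrors the call pattern exercised by the new test; the shapes and the per-class weights [0.1, 0.5, 2.0] are illustrative, not required values.

# Sketch only: alpha given as one weight per class, mirroring the new test setup.
import torch
from monai.losses import FocalLoss

logits = torch.randn(2, 3, 4, 4)            # (batch, classes, H, W)
target = torch.randint(0, 3, (2, 1, 4, 4))  # class-index labels; one-hot conversion via to_onehot_y
loss_fn = FocalLoss(to_onehot_y=True, gamma=2.0, alpha=[0.1, 0.5, 2.0], use_softmax=True, reduction="mean")
print(loss_fn(logits, target))              # expected: a positive scalar tensor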

tests/losses/test_focal_loss.py

Lines changed: 41 additions & 0 deletions
@@ -374,6 +374,47 @@ def test_script(self):
         test_input = torch.ones(2, 2, 8, 8)
         test_script_save(loss, test_input, test_input)
 
+    def test_alpha_sequence_broadcasting(self):
+        """
+        Test FocalLoss with alpha as a sequence for proper broadcasting.
+        """
+        num_classes = 3
+        alpha_seq = [0.1, 0.5, 2.0]
+        batch_size = 2
+        spatial_dims = (4, 4)
+
+        devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+
+        for device in devices:
+            logits = torch.randn(batch_size, num_classes, *spatial_dims, device=device)
+            target = torch.randint(0, num_classes, (batch_size, 1, *spatial_dims), device=device)
+
+            # Case 1: Softmax + Alpha Sequence
+            loss_func_softmax = FocalLoss(
+                to_onehot_y=True, gamma=2.0, alpha=alpha_seq, use_softmax=True, reduction="mean"
+            )
+            loss_soft = loss_func_softmax(logits, target)
+
+            self.assertTrue(torch.is_tensor(loss_soft))
+            self.assertEqual(loss_soft.ndim, 0)
+            self.assertTrue(loss_soft > 0, f"Softmax loss on {device} should be positive")
+
+            # Case 2: Sigmoid + Alpha Sequence
+            loss_func_sigmoid = FocalLoss(
+                to_onehot_y=True, gamma=2.0, alpha=alpha_seq, use_softmax=False, reduction="mean"
+            )
+            loss_sig = loss_func_sigmoid(logits, target)
+
+            self.assertTrue(torch.is_tensor(loss_sig))
+            self.assertEqual(loss_sig.ndim, 0)
+            self.assertTrue(loss_sig > 0, f"Sigmoid loss on {device} should be positive")
+
+            # Case 3: Error Handling (Mismatched alpha length)
+            if device == devices[0]:
+                wrong_alpha = [0.1, 0.5]
+                with self.assertRaisesRegex(ValueError, "length of alpha"):
+                    FocalLoss(to_onehot_y=True, alpha=wrong_alpha, use_softmax=True)(logits, target)
+
 
 if __name__ == "__main__":
     unittest.main()
