|
21 | 21 |
|
22 | 22 | from monai.losses import FocalLoss |
23 | 23 | from monai.networks import one_hot |
24 | | -from tests.test_utils import test_script_save |
| 24 | +from tests.test_utils import test_script_save, TEST_DEVICES |
25 | 25 |
|
26 | 26 | TEST_CASES = [] |
27 | 27 | for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: |
|
77 | 77 | TEST_CASES.append([{"to_onehot_y": True, "use_softmax": True}, input_data, 0.16276]) |
78 | 78 | TEST_CASES.append([{"to_onehot_y": True, "alpha": 0.8, "use_softmax": True}, input_data, 0.08138]) |
79 | 79 |
|
| 80 | +TEST_ALPHA_BROADCASTING = [] |
| 81 | +for case in TEST_DEVICES: |
| 82 | + device = case[0] |
| 83 | + for include_background in [True, False]: |
| 84 | + for use_softmax in [True, False]: |
| 85 | + TEST_ALPHA_BROADCASTING.append([device, include_background, use_softmax]) |
| 86 | + |
80 | 87 |
|
81 | 88 | class TestFocalLoss(unittest.TestCase): |
82 | 89 | @parameterized.expand(TEST_CASES) |
@@ -374,46 +381,39 @@ def test_script(self): |
374 | 381 | test_input = torch.ones(2, 2, 8, 8) |
375 | 382 | test_script_save(loss, test_input, test_input) |
376 | 383 |
|
377 | | - def test_alpha_sequence_broadcasting(self): |
| 384 | + @parameterized.expand(TEST_ALPHA_BROADCASTING) |
| 385 | + def test_alpha_sequence_broadcasting(self, device, include_background, use_softmax): |
378 | 386 | """ |
379 | 387 | Test FocalLoss with alpha as a sequence for proper broadcasting. |
380 | 388 | """ |
381 | 389 | num_classes = 3 |
382 | | - alpha_seq = [0.1, 0.5, 2.0] |
383 | 390 | batch_size = 2 |
384 | 391 | spatial_dims = (4, 4) |
385 | 392 |
|
386 | | - devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] |
387 | | - |
388 | | - for device in devices: |
389 | | - logits = torch.randn(batch_size, num_classes, *spatial_dims, device=device) |
390 | | - target = torch.randint(0, num_classes, (batch_size, 1, *spatial_dims), device=device) |
391 | | - |
392 | | - # Case 1: Softmax + Alpha Sequence |
393 | | - loss_func_softmax = FocalLoss( |
394 | | - to_onehot_y=True, gamma=2.0, alpha=alpha_seq, use_softmax=True, reduction="mean" |
395 | | - ) |
396 | | - loss_soft = loss_func_softmax(logits, target) |
397 | | - |
398 | | - self.assertTrue(torch.is_tensor(loss_soft)) |
399 | | - self.assertEqual(loss_soft.ndim, 0) |
400 | | - self.assertTrue(loss_soft > 0, f"Softmax loss on {device} should be positive") |
401 | | - |
402 | | - # Case 2: Sigmoid + Alpha Sequence |
403 | | - loss_func_sigmoid = FocalLoss( |
404 | | - to_onehot_y=True, gamma=2.0, alpha=alpha_seq, use_softmax=False, reduction="mean" |
405 | | - ) |
406 | | - loss_sig = loss_func_sigmoid(logits, target) |
407 | | - |
408 | | - self.assertTrue(torch.is_tensor(loss_sig)) |
409 | | - self.assertEqual(loss_sig.ndim, 0) |
410 | | - self.assertTrue(loss_sig > 0, f"Sigmoid loss on {device} should be positive") |
411 | | - |
412 | | - # Case 3: Error Handling (Mismatched alpha length) |
413 | | - if device == devices[0]: |
414 | | - wrong_alpha = [0.1, 0.5] |
415 | | - with self.assertRaisesRegex(ValueError, "length of alpha"): |
416 | | - FocalLoss(to_onehot_y=True, alpha=wrong_alpha, use_softmax=True)(logits, target) |
| 393 | + logits = torch.randn(batch_size, num_classes, *spatial_dims, device=device) |
| 394 | + target = torch.randint(0, num_classes, (batch_size, 1, *spatial_dims), device=device) |
| 395 | + |
| 396 | + if include_background: |
| 397 | + alpha_seq = [0.1, 0.5, 2.0] |
| 398 | + else: |
| 399 | + alpha_seq = [0.5, 2.0] |
| 400 | + |
| 401 | + loss_func = FocalLoss( |
| 402 | + to_onehot_y=True, |
| 403 | + gamma=2.0, |
| 404 | + alpha=alpha_seq, |
| 405 | + include_background=include_background, |
| 406 | + use_softmax=use_softmax, |
| 407 | + reduction="mean", |
| 408 | + ) |
| 409 | + |
| 410 | + result = loss_func(logits, target) |
| 411 | + |
| 412 | + self.assertTrue(torch.is_tensor(result)) |
| 413 | + self.assertEqual(result.ndim, 0) |
| 414 | + self.assertTrue( |
| 415 | + result > 0, f"Loss should be positive. params: dev={device}, bg={include_background}, softmax={use_softmax}" |
| 416 | + ) |
417 | 417 |
|
418 | 418 |
|
419 | 419 | if __name__ == "__main__": |
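
For context, here is a minimal standalone sketch of the per-class alpha broadcasting that the new test exercises. It assumes the FocalLoss branch under review accepts a sequence for `alpha` (one weight per scored class); released MONAI versions may only accept a single float. Shapes and parameter choices mirror the test above.

import torch
from monai.losses import FocalLoss

# Sketch only: per-class alpha weights for a 3-class problem.
# Input is BCHW logits, target is B1HW class indices that `to_onehot_y`
# expands internally. The sequence form of `alpha` is the feature the
# new test covers and is assumed here, not guaranteed by released MONAI.
logits = torch.randn(2, 3, 4, 4)            # B=2, C=3, 4x4 spatial
target = torch.randint(0, 3, (2, 1, 4, 4))  # class indices, single channel

loss_fn = FocalLoss(
    to_onehot_y=True,        # convert target indices to one-hot internally
    gamma=2.0,               # standard focusing parameter
    alpha=[0.1, 0.5, 2.0],   # one weight per class, broadcast over batch and spatial dims
    include_background=True, # alpha length must match the number of scored classes
    use_softmax=True,        # multi-class (softmax) formulation
    reduction="mean",
)
print(loss_fn(logits, target))  # scalar tensor, expected to be positive

When `include_background=False`, the background channel is dropped before weighting, which is why the test shortens `alpha_seq` to two entries in that branch.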
|