Skip to content

Commit c550c29

Browse files
committed
Correct masked mean computation in AUCMLoss and update docstrings
Signed-off-by: Shubham Chandravanshi <[email protected]>
1 parent f1d38f4 commit c550c29

File tree

2 files changed

+27
-11
lines changed

2 files changed

+27
-11
lines changed

monai/losses/aucm_loss.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
from __future__ import annotations
1313

14-
1514
import torch
1615
import torch.nn as nn
1716
from torch.nn.modules.loss import _Loss
@@ -93,6 +92,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
9392
input: the shape should be B1HW[D], where the channel dimension is 1 for binary classification.
9493
target: the shape should be B1HW[D], with values 0 or 1.
9594
95+
Returns:
96+
torch.Tensor: scalar AUCM loss.
97+
9698
Raises:
9799
ValueError: When input or target have incorrect shapes.
98100
"""
@@ -112,25 +114,29 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
112114
if self.version == "v1":
113115
p = self.imratio if self.imratio is not None else pos_mask.mean()
114116
loss = (
115-
(1 - p) * self._safe_mean((input - self.a) ** 2 * pos_mask)
116-
+ p * self._safe_mean((input - self.b) ** 2 * neg_mask)
117+
(1 - p) * self._safe_mean((input - self.a) ** 2, pos_mask)
118+
+ p * self._safe_mean((input - self.b) ** 2, neg_mask)
117119
+ 2
118120
* self.alpha
119-
* (p * (1 - p) * self.margin + self._safe_mean(p * input * neg_mask - (1 - p) * input * pos_mask))
121+
* (
122+
p * (1 - p) * self.margin
123+
+ self._safe_mean(p * input * neg_mask - (1 - p) * input * pos_mask, pos_mask + neg_mask)
124+
)
120125
- p * (1 - p) * self.alpha**2
121126
)
122127
else:
123128
loss = (
124-
self._safe_mean((input - self.a) ** 2 * pos_mask)
125-
+ self._safe_mean((input - self.b) ** 2 * neg_mask)
126-
+ 2 * self.alpha * (self.margin + self._safe_mean(input * neg_mask) - self._safe_mean(input * pos_mask))
129+
self._safe_mean((input - self.a) ** 2, pos_mask)
130+
+ self._safe_mean((input - self.b) ** 2, neg_mask)
131+
+ 2 * self.alpha * (self.margin + self._safe_mean(input, neg_mask) - self._safe_mean(input, pos_mask))
127132
- self.alpha**2
128133
)
129134

130135
return loss
131136

132-
def _safe_mean(self, tensor: torch.Tensor) -> torch.Tensor:
133-
"""Compute mean safely, returning 0 if tensor is empty."""
134-
if tensor.numel() == 0 or tensor.count_nonzero() == 0:
137+
def _safe_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
138+
"""Compute mean safely over masked elements."""
139+
denom = mask.sum()
140+
if denom == 0:
135141
return torch.tensor(0.0, device=tensor.device, dtype=tensor.dtype)
136-
return tensor.sum() / tensor.count_nonzero()
142+
return (tensor * mask).sum() / denom

tests/losses/test_aucm_loss.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020

2121

2222
class TestAUCMLoss(unittest.TestCase):
23+
"""Test cases for AUCMLoss."""
24+
2325
def test_v1(self):
26+
"""Test AUCMLoss with version 'v1'."""
2427
loss_fn = AUCMLoss(version="v1")
2528
input = torch.randn(32, 1, requires_grad=True)
2629
target = torch.randint(0, 2, (32, 1)).float()
@@ -29,6 +32,7 @@ def test_v1(self):
2932
self.assertEqual(loss.ndim, 0)
3033

3134
def test_v2(self):
35+
"""Test AUCMLoss with version 'v2'."""
3236
loss_fn = AUCMLoss(version="v2")
3337
input = torch.randn(32, 1, requires_grad=True)
3438
target = torch.randint(0, 2, (32, 1)).float()
@@ -37,31 +41,36 @@ def test_v2(self):
3741
self.assertEqual(loss.ndim, 0)
3842

3943
def test_invalid_version(self):
44+
"""Test that invalid version raises ValueError."""
4045
with self.assertRaises(ValueError):
4146
AUCMLoss(version="invalid")
4247

4348
def test_invalid_input_shape(self):
49+
"""Test that invalid input shape raises ValueError."""
4450
loss_fn = AUCMLoss()
4551
input = torch.randn(32, 2) # Wrong channel
4652
target = torch.randint(0, 2, (32, 1)).float()
4753
with self.assertRaises(ValueError):
4854
loss_fn(input, target)
4955

5056
def test_invalid_target_shape(self):
57+
"""Test that invalid target shape raises ValueError."""
5158
loss_fn = AUCMLoss()
5259
input = torch.randn(32, 1)
5360
target = torch.randint(0, 2, (32, 2)).float() # Wrong channel
5461
with self.assertRaises(ValueError):
5562
loss_fn(input, target)
5663

5764
def test_shape_mismatch(self):
65+
"""Test that mismatched shapes raise ValueError."""
5866
loss_fn = AUCMLoss()
5967
input = torch.randn(32, 1)
6068
target = torch.randint(0, 2, (16, 1)).float()
6169
with self.assertRaises(ValueError):
6270
loss_fn(input, target)
6371

6472
def test_backward(self):
73+
"""Test that gradients can be computed."""
6574
loss_fn = AUCMLoss()
6675
input = torch.randn(32, 1, requires_grad=True)
6776
target = torch.randint(0, 2, (32, 1)).float()
@@ -70,6 +79,7 @@ def test_backward(self):
7079
self.assertIsNotNone(input.grad)
7180

7281
def test_script_save(self):
82+
"""Test that the loss can be saved as TorchScript."""
7383
loss_fn = AUCMLoss()
7484
test_script_save(loss_fn, torch.randn(32, 1), torch.randint(0, 2, (32, 1)).float())
7585

0 commit comments

Comments
 (0)