Commit 3ee0c07

Add AUC-Margin loss for AUROC optimization (#4609)
Signed-off-by: Shubham Chandravanshi <[email protected]>
1 parent 57fdd59 commit 3ee0c07

3 files changed: +216 -0 lines changed

monai/losses/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@
 from __future__ import annotations

 from .adversarial_loss import PatchAdversarialLoss
+from .aucm_loss import AUCMLoss
 from .barlow_twins import BarlowTwinsLoss
 from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
 from .contrastive import ContrastiveLoss
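
With this re-export in place, the new loss is importable directly from the package namespace, as the docstring examples and tests below assume:

    from monai.losses import AUCMLoss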

monai/losses/aucm_loss.py

Lines changed: 137 additions & 0 deletions

@@ -0,0 +1,137 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings

import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss

from monai.utils import LossReduction


class AUCMLoss(_Loss):
    """
    AUC-Margin (AUCM) loss with a squared-hinge surrogate for optimizing AUROC.

    The loss optimizes the Area Under the ROC Curve (AUROC) by applying margin-based constraints
    to positive and negative predictions. Two versions are supported: 'v1' includes class prior
    information, while 'v2' removes this dependency for better generalization.

    Reference:
        Yuan, Zhuoning, Yan, Yan, Sonka, Milan, and Yang, Tianbao.
        "Large-scale Robust Deep AUC Maximization: A New Surrogate Loss and Empirical Studies on Medical Image Classification."
        Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021.
        https://arxiv.org/abs/2012.03173

    Implementation based on: https://github.com/Optimization-AI/LibAUC/blob/1.4.0/libauc/losses/auc.py

    Example:
        >>> import torch
        >>> from monai.losses import AUCMLoss
        >>> loss_fn = AUCMLoss()
        >>> input = torch.randn(32, 1, requires_grad=True)
        >>> target = torch.randint(0, 2, (32, 1)).float()
        >>> loss = loss_fn(input, target)
    """

    def __init__(
        self,
        margin: float = 1.0,
        imratio: float | None = None,
        version: str = "v1",
        reduction: LossReduction | str = LossReduction.MEAN,
    ) -> None:
        """
        Args:
            margin: margin for the squared-hinge surrogate loss (default: ``1.0``).
            imratio: ratio of the number of positive samples to the total number of samples in the training dataset.
                If not given, it is estimated from each mini-batch.
                This value is ignored when ``version`` is set to ``'v2'``.
            version: which formulation of the objective to use (default: ``'v1'``).
                ``'v1'`` includes the class prior, ``'v2'`` removes this dependency.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

        Raises:
            ValueError: When ``version`` is not one of ["v1", "v2"].

        Example:
            >>> import torch
            >>> from monai.losses import AUCMLoss
            >>> loss_fn = AUCMLoss(version='v2')
            >>> input = torch.randn(32, 1, requires_grad=True)
            >>> target = torch.randint(0, 2, (32, 1)).float()
            >>> loss = loss_fn(input, target)
        """
        super().__init__(reduction=LossReduction(reduction).value)
        if version not in ["v1", "v2"]:
            raise ValueError(f"version should be 'v1' or 'v2', got {version}")
        self.margin = margin
        self.imratio = imratio
        self.version = version
        self.a = nn.Parameter(torch.tensor(0.0))
        self.b = nn.Parameter(torch.tensor(0.0))
        self.alpha = nn.Parameter(torch.tensor(0.0))

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: predicted scores with channel dimension 1, e.g. shape (B, 1) or B1HW[D]; flattened internally.
            target: same shape as ``input``, with values 0 or 1.

        Raises:
            ValueError: When input or target have incorrect shapes.
        """
        if input.shape[1] != 1:
            raise ValueError(f"Input should have 1 channel for binary classification, got {input.shape[1]}")
        if target.shape[1] != 1:
            raise ValueError(f"Target should have 1 channel, got {target.shape[1]}")
        if input.shape != target.shape:
            raise ValueError(f"Input and target shapes do not match: {input.shape} vs {target.shape}")

        input = input.flatten()
        target = target.flatten()

        pos_mask = (target == 1).float()
        neg_mask = (target == 0).float()

        if self.version == "v1":
            p = self.imratio if self.imratio is not None else pos_mask.mean()
            loss = (
                (1 - p) * self._safe_mean((input - self.a) ** 2 * pos_mask)
                + p * self._safe_mean((input - self.b) ** 2 * neg_mask)
                + 2
                * self.alpha
                * (p * (1 - p) * self.margin + self._safe_mean(p * input * neg_mask - (1 - p) * input * pos_mask))
                - p * (1 - p) * self.alpha**2
            )
        else:
            loss = (
                self._safe_mean((input - self.a) ** 2 * pos_mask)
                + self._safe_mean((input - self.b) ** 2 * neg_mask)
                + 2 * self.alpha * (self.margin + self._safe_mean(input * neg_mask) - self._safe_mean(input * pos_mask))
                - self.alpha**2
            )

        return loss

    def _safe_mean(self, tensor: torch.Tensor) -> torch.Tensor:
        """Compute the mean over non-zero elements, returning 0 if the tensor is empty or all zeros."""
        if tensor.numel() == 0 or tensor.count_nonzero() == 0:
            return torch.tensor(0.0, device=tensor.device, dtype=tensor.dtype)
        return tensor.sum() / tensor.count_nonzero()
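
Because AUCMLoss registers a, b, and alpha as nn.Parameter attributes, they need to be updated by the optimizer alongside the model weights. Below is a minimal training sketch, not part of this commit: the linear model, sigmoid activation, and plain SGD are illustrative assumptions, and note that the paper's PESG optimizer performs gradient ascent on alpha rather than the plain descent shown here.

    import torch
    from monai.losses import AUCMLoss

    model = torch.nn.Linear(10, 1)  # hypothetical binary classifier head
    loss_fn = AUCMLoss(margin=1.0, version="v1")
    # include the loss module's learnable variables (a, b, alpha) in the optimizer
    optimizer = torch.optim.SGD(list(model.parameters()) + list(loss_fn.parameters()), lr=0.1)

    x = torch.randn(32, 10)
    y = torch.randint(0, 2, (32, 1)).float()

    optimizer.zero_grad()
    pred = torch.sigmoid(model(x))  # scores in [0, 1], shape (32, 1)
    loss = loss_fn(pred, y)
    loss.backward()
    optimizer.step()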

tests/losses/test_aucm_loss.py

Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch

from monai.losses import AUCMLoss
from tests.test_utils import test_script_save


class TestAUCMLoss(unittest.TestCase):
    def test_v1(self):
        loss_fn = AUCMLoss(version="v1")
        input = torch.randn(32, 1, requires_grad=True)
        target = torch.randint(0, 2, (32, 1)).float()
        loss = loss_fn(input, target)
        self.assertIsInstance(loss, torch.Tensor)
        self.assertEqual(loss.ndim, 0)

    def test_v2(self):
        loss_fn = AUCMLoss(version="v2")
        input = torch.randn(32, 1, requires_grad=True)
        target = torch.randint(0, 2, (32, 1)).float()
        loss = loss_fn(input, target)
        self.assertIsInstance(loss, torch.Tensor)
        self.assertEqual(loss.ndim, 0)

    def test_invalid_version(self):
        with self.assertRaises(ValueError):
            AUCMLoss(version="invalid")

    def test_invalid_input_shape(self):
        loss_fn = AUCMLoss()
        input = torch.randn(32, 2)  # Wrong channel
        target = torch.randint(0, 2, (32, 1)).float()
        with self.assertRaises(ValueError):
            loss_fn(input, target)

    def test_invalid_target_shape(self):
        loss_fn = AUCMLoss()
        input = torch.randn(32, 1)
        target = torch.randint(0, 2, (32, 2)).float()  # Wrong channel
        with self.assertRaises(ValueError):
            loss_fn(input, target)

    def test_shape_mismatch(self):
        loss_fn = AUCMLoss()
        input = torch.randn(32, 1)
        target = torch.randint(0, 2, (16, 1)).float()
        with self.assertRaises(ValueError):
            loss_fn(input, target)

    def test_backward(self):
        loss_fn = AUCMLoss()
        input = torch.randn(32, 1, requires_grad=True)
        target = torch.randint(0, 2, (32, 1)).float()
        loss = loss_fn(input, target)
        loss.backward()
        self.assertIsNotNone(input.grad)

    def test_script_save(self):
        loss_fn = AUCMLoss()
        test_script_save(loss_fn, torch.randn(32, 1), torch.randint(0, 2, (32, 1)).float())


if __name__ == "__main__":
    unittest.main()
