Skip to content

Commit 2a56f54

Browse files
committed
Validate binary targets, clarify reduction, and fix AUCM typing
Signed-off-by: Shubham Chandravanshi <shubham.chandravanshi378@gmail.com>
1 parent c550c29 commit 2a56f54

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

monai/losses/aucm_loss.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,8 @@ def __init__(
6060
'v1' includes class prior, 'v2' removes this dependency.
6161
reduction: {``"none"``, ``"mean"``, ``"sum"``}
6262
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
63-
64-
- ``"none"``: no reduction will be applied.
65-
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
66-
- ``"sum"``: the output will be summed.
63+
Note: This loss is computed at the batch level and always returns a scalar.
64+
The reduction parameter is accepted for API consistency but has no effect.
6765
6866
Raises:
6967
ValueError: When ``version`` is not one of ["v1", "v2"].
@@ -97,6 +95,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
9795
9896
Raises:
9997
ValueError: When input or target have incorrect shapes.
98+
ValueError: When target contains non-binary values.
10099
"""
101100
if input.shape[1] != 1:
102101
raise ValueError(f"Input should have 1 channel for binary classification, got {input.shape[1]}")
@@ -108,11 +107,14 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
108107
input = input.flatten()
109108
target = target.flatten()
110109

110+
if not torch.all((target == 0) | (target == 1)):
111+
raise ValueError("Target must contain only binary values (0 or 1)")
112+
111113
pos_mask = (target == 1).float()
112114
neg_mask = (target == 0).float()
113115

114116
if self.version == "v1":
115-
p = self.imratio if self.imratio is not None else pos_mask.mean()
117+
p = float(self.imratio) if self.imratio is not None else float(pos_mask.mean().item())
116118
loss = (
117119
(1 - p) * self._safe_mean((input - self.a) ** 2, pos_mask)
118120
+ p * self._safe_mean((input - self.b) ** 2, neg_mask)

tests/losses/test_aucm_loss.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@ def test_shape_mismatch(self):
6969
with self.assertRaises(ValueError):
7070
loss_fn(input, target)
7171

72+
def test_non_binary_target(self):
73+
"""Test that non-binary target values raise ValueError."""
74+
loss_fn = AUCMLoss()
75+
input = torch.randn(32, 1)
76+
target = torch.tensor([[0.5], [1.0], [2.0]] * 10 + [[0.0]]) # Contains non-binary values
77+
with self.assertRaises(ValueError):
78+
loss_fn(input, target)
79+
7280
def test_backward(self):
7381
"""Test that gradients can be computed."""
7482
loss_fn = AUCMLoss()

0 commit comments

Comments
 (0)