@@ -60,10 +60,8 @@ def __init__(
6060 'v1' includes class prior, 'v2' removes this dependency.
6161 reduction: {``"none"``, ``"mean"``, ``"sum"``}
6262 Specifies the reduction to apply to the output. Defaults to ``"mean"``.
63-
64- - ``"none"``: no reduction will be applied.
65- - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
66- - ``"sum"``: the output will be summed.
63+ Note: This loss is computed at the batch level and always returns a scalar.
64+ The reduction parameter is accepted for API consistency but has no effect.
6765
6866 Raises:
6967 ValueError: When ``version`` is not one of ["v1", "v2"].
@@ -97,6 +95,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
9795
9896 Raises:
9997 ValueError: When input or target have incorrect shapes.
98+ ValueError: When target contains non-binary values.
10099 """
101100 if input .shape [1 ] != 1 :
102101 raise ValueError (f"Input should have 1 channel for binary classification, got { input .shape [1 ]} " )
@@ -108,11 +107,14 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
108107 input = input .flatten ()
109108 target = target .flatten ()
110109
110+ if not torch .all ((target == 0 ) | (target == 1 )):
111+ raise ValueError ("Target must contain only binary values (0 or 1)" )
112+
111113 pos_mask = (target == 1 ).float ()
112114 neg_mask = (target == 0 ).float ()
113115
114116 if self .version == "v1" :
115- p = self .imratio if self .imratio is not None else pos_mask .mean ()
117+ p = float ( self .imratio ) if self .imratio is not None else float ( pos_mask .mean (). item () )
116118 loss = (
117119 (1 - p ) * self ._safe_mean ((input - self .a ) ** 2 , pos_mask )
118120 + p * self ._safe_mean ((input - self .b ) ** 2 , neg_mask )
0 commit comments