Skip to content

Commit 7a100d9

Browse files
committed
fix: reduction=NONE
Signed-off-by: ytl0623 <[email protected]>
1 parent b08de65 commit 7a100d9

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 24 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7171
y_pred = torch.sigmoid(y_pred)
7272
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
7373
is_already_prob = True
74+
# Expand y_true to match if it's single channel
7475
if y_true.shape[1] == 1:
7576
y_true = one_hot(y_true, num_classes=2)
7677
else:
@@ -213,17 +214,15 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
213214
# Concatenate losses
214215
all_ce = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1)
215216

216-
# Sum over classes (dim=1) to get total loss per pixel
217-
total_loss = torch.sum(all_ce, dim=1)
218-
219217
# Apply reduction
220218
if self.reduction == LossReduction.MEAN.value:
221-
return torch.mean(total_loss)
219+
return torch.mean(torch.sum(all_ce, dim=1))
222220
if self.reduction == LossReduction.SUM.value:
223-
return torch.sum(total_loss)
221+
return torch.sum(all_ce)
224222
if self.reduction == LossReduction.NONE.value:
225-
return total_loss
226-
return torch.mean(total_loss)
223+
return all_ce
224+
225+
return torch.mean(torch.sum(all_ce, dim=1))
227226

228227

229228
class AsymmetricUnifiedFocalLoss(_Loss):
@@ -268,14 +267,14 @@ def __init__(
268267
delta=self.delta,
269268
use_softmax=self.use_softmax,
270269
to_onehot_y=to_onehot_y,
271-
reduction=reduction,
270+
reduction=LossReduction.NONE,
272271
)
273272
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
274273
gamma=self.gamma,
275274
delta=self.delta,
276275
use_softmax=self.use_softmax,
277276
to_onehot_y=to_onehot_y,
278-
reduction=reduction,
277+
reduction=LossReduction.NONE,
279278
)
280279

281280
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
@@ -293,6 +292,21 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
293292
asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
294293
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
295294

296-
loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss
295+
# Align Focal Loss to (B, C) by averaging over spatial dimensions
296+
spatial_dims = list(range(2, len(asy_focal_loss.shape)))
297+
focal_aligned = torch.mean(asy_focal_loss, dim=spatial_dims)
298+
299+
# Calculate weighted sum. Result shape: (B, C)
300+
combined_loss = self.weight * focal_aligned + (1 - self.weight) * asy_focal_tversky_loss
301+
302+
loss: torch.Tensor
303+
if self.reduction == LossReduction.MEAN.value:
304+
loss = torch.mean(combined_loss)
305+
elif self.reduction == LossReduction.SUM.value:
306+
loss = torch.sum(combined_loss)
307+
elif self.reduction == LossReduction.NONE.value:
308+
loss = combined_loss
309+
else:
310+
loss = torch.mean(combined_loss)
297311

298312
return loss

0 commit comments

Comments (0)