Skip to content

Commit ecf33eb

Browse files
committed
Added a local variable `is_converted_to_prob` that tracks whether the predictions have already been converted to probabilities, so the sigmoid/softmax activation is not applied twice.
Signed-off-by: ytl0623 <[email protected]>
1 parent f0772ed commit ecf33eb

File tree

1 file changed

+44
-23
lines changed

1 file changed

+44
-23
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -71,20 +71,26 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7171
y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims).
7272
y_true: ground truth labels. Shape should match y_pred.
7373
"""
74-
74+
75+
# Track whether we've already converted to probabilities
76+
is_converted_to_prob = self.is_prob_input
77+
7578
# Handle single-channel binary case
7679
if y_pred.shape[1] == 1:
7780
# For binary segmentation, enforce num_classes == 2
7881
if self.num_classes != 2:
79-
raise ValueError(f"Single-channel input requires num_classes=2, but got num_classes={self.num_classes}")
80-
82+
raise ValueError(
83+
f"Single-channel input requires num_classes=2, but got num_classes={self.num_classes}"
84+
)
85+
8186
# Convert to probability if needed
82-
if not self.is_prob_input:
87+
if not is_converted_to_prob:
8388
y_pred = torch.sigmoid(y_pred)
84-
89+
is_converted_to_prob = True
90+
8591
# Expand to 2 channels
8692
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
87-
93+
8894
# Convert y_true to one-hot with 2 classes (matching expanded y_pred)
8995
if y_true.shape[1] == 1:
9096
y_true = one_hot(y_true, num_classes=2)
@@ -107,14 +113,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
107113
if y_true.shape[1] == 1:
108114
# Validate class indices
109115
if torch.max(y_true) >= self.num_classes:
110-
raise ValueError(f"Ground truth contains class indices >= {self.num_classes}")
116+
raise ValueError(
117+
f"Ground truth contains class indices >= {self.num_classes}"
118+
)
111119
y_true = one_hot(y_true, num_classes=self.num_classes)
112120

113121
if y_true.shape != y_pred.shape:
114122
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
115123

116124
# Convert logits to probabilities if not already done
117-
if not self.is_prob_input and y_pred.shape[1] != 1:
125+
if not is_converted_to_prob:
118126
if self.use_softmax:
119127
y_pred = torch.softmax(y_pred, dim=1)
120128
else:
@@ -203,19 +211,25 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
203211
y_true: ground truth labels.
204212
"""
205213

214+
# Track whether we've already converted to probabilities
215+
is_converted_to_prob = self.is_prob_input
216+
206217
# Handle single-channel binary case
207218
if y_pred.shape[1] == 1:
208219
# For binary segmentation, enforce num_classes == 2
209220
if self.num_classes != 2:
210-
raise ValueError(f"Single-channel input requires num_classes=2, but got num_classes={self.num_classes}")
211-
221+
raise ValueError(
222+
f"Single-channel input requires num_classes=2, but got num_classes={self.num_classes}"
223+
)
224+
212225
# Convert to probability if needed
213-
if not self.is_prob_input:
226+
if not is_converted_to_prob:
214227
y_pred = torch.sigmoid(y_pred)
215-
228+
is_converted_to_prob = True
229+
216230
# Expand to 2 channels
217231
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
218-
232+
219233
# Convert y_true to one-hot with 2 classes (matching expanded y_pred)
220234
if y_true.shape[1] == 1:
221235
y_true = one_hot(y_true, num_classes=2)
@@ -238,14 +252,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
238252
if y_true.shape[1] == 1:
239253
# Validate class indices
240254
if torch.max(y_true) >= self.num_classes:
241-
raise ValueError(f"Ground truth contains class indices >= {self.num_classes}")
255+
raise ValueError(
256+
f"Ground truth contains class indices >= {self.num_classes}"
257+
)
242258
y_true = one_hot(y_true, num_classes=self.num_classes)
243259

244260
if y_true.shape != y_pred.shape:
245261
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
246262

247263
# Convert logits to probabilities if not already done
248-
if not self.is_prob_input and y_pred.shape[1] != 1:
264+
if not is_converted_to_prob:
249265
if self.use_softmax:
250266
y_pred = torch.softmax(y_pred, dim=1)
251267
else:
@@ -267,10 +283,12 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
267283

268284
# Apply reduction
269285
if self.reduction == LossReduction.MEAN.value:
286+
# For MEAN: average over all elements (batch and classes and spatial)
270287
return torch.mean(all_ce)
271288
if self.reduction == LossReduction.SUM.value:
272289
return torch.sum(all_ce)
273290
if self.reduction == LossReduction.NONE.value:
291+
# For NONE: return full tensor with all spatial dimensions
274292
return all_ce
275293

276294
return torch.mean(all_ce)
@@ -316,11 +334,13 @@ def __init__(
316334
>>> fl(pred, grnd)
317335
"""
318336
super().__init__(reduction=LossReduction(reduction).value)
319-
337+
320338
# Validate configuration
321339
if use_softmax and num_classes < 2:
322-
raise ValueError(f"use_softmax=True requires num_classes >= 2, but got num_classes={num_classes}")
323-
340+
raise ValueError(
341+
f"use_softmax=True requires num_classes >= 2, but got num_classes={num_classes}"
342+
)
343+
324344
self.to_onehot_y = to_onehot_y
325345
self.num_classes = num_classes
326346
self.gamma = gamma
@@ -362,18 +382,18 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
362382
ValueError: When the number of classes in ground truth exceeds the configured `num_classes`.
363383
ValueError: When use_softmax=True with single-channel input.
364384
"""
365-
385+
366386
# Validate input dimensions
367387
if len(y_pred.shape) not in [4, 5]:
368388
raise ValueError(f"Input shape must be 4 (2D) or 5 (3D), but got {y_pred.shape}")
369-
389+
370390
# Validate use_softmax with single channel
371391
if self.use_softmax and y_pred.shape[1] == 1:
372392
raise ValueError(
373393
"use_softmax=True is invalid with single-channel input (C=1). "
374394
"Use use_softmax=False for binary segmentation with sigmoid activation."
375395
)
376-
396+
377397
# Shape validation
378398
if y_pred.shape != y_true.shape:
379399
# Allow mismatch only for valid cases
@@ -386,13 +406,14 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
386406
"and this mismatch cannot be resolved by `to_onehot_y` or binary expansion."
387407
)
388408

409+
# Pre-process y_true if needed (will be done inside sub-losses, but validate here)
389410
if self.to_onehot_y and y_true.shape[1] == 1:
390411
# Check indices validity before conversion
391412
if torch.max(y_true) >= self.num_classes:
392413
raise ValueError(
393414
f"Ground truth contains class indices >= {self.num_classes}, which exceeds num_classes."
394415
)
395-
416+
396417
# Get losses from sub-losses
397418
# AsymmetricFocalLoss with NONE reduction returns (B, C, H, W, [D])
398419
# AsymmetricFocalTverskyLoss with NONE reduction returns (B, C)
@@ -421,4 +442,4 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
421442
else:
422443
loss = torch.mean(combined_loss)
423444

424-
return loss
445+
return loss

0 commit comments

Comments (0)