Commit a7e34de

DiceBCELossWithLogits is expected to support multiclass (#110)

* Updated loss computation to handle multi-class logit conversion. (#109)
* Fixed label conversion order in loss calculation to ensure compatible data types. (#109)
* Ensured tensor contiguity in one-hot encoding conversion. (#109)
1 parent 86d0517 commit a7e34de
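
To make the intended use concrete, here is a hedged usage sketch. The import path is guessed from the file paths below, and the tensor shapes are illustrative; only the constructor's `num_classes` parameter and the `forward(masks, labels)` signature returning `tuple[torch.Tensor, dict[str, float]]` come from the diff itself.

```python
import torch
from mipcandy.common.optim.loss import DiceBCELossWithLogits  # import path assumed from the file paths

criterion = DiceBCELossWithLogits(num_classes=3)

masks = torch.randn(2, 3, 4, 4)             # raw network logits, (N, C, H, W)
labels = torch.randint(0, 3, (2, 1, 4, 4))  # single-channel integer class IDs, (N, 1, H, W)

# After this commit, forward() one-hot expands the ID labels to (N, C, H, W)
# before computing the combined BCE + soft Dice loss.
loss, parts = criterion(masks, labels)      # parts: dict[str, float] per the signature
```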

File tree

2 files changed: +7 −1 lines changed

mipcandy/common/optim/loss.py (6 additions, 0 deletions)

```diff
@@ -3,6 +3,7 @@
 import torch
 from torch import nn
 
+from mipcandy.data import convert_ids_to_logits
 from mipcandy.metrics import do_reduction, soft_dice_coefficient
 
 
@@ -33,6 +34,11 @@ def __init__(self, num_classes: int, *, lambda_bce: float = .5, lambda_soft_dice
         self.smooth: float = smooth
 
     def forward(self, masks: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
+        if self.num_classes != 1 and labels.shape[1] == 1:
+            d = labels.ndim - 2
+            if d not in (1, 2, 3):
+                raise ValueError(f"Expected labels to be 1D, 2D, or 3D, got {d} spatial dimensions")
+            labels = convert_ids_to_logits(labels.int(), d, self.num_classes)
         labels = labels.float()
         bce = nn.functional.binary_cross_entropy_with_logits(masks, labels)
         masks = masks.sigmoid()
```
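
For readers who want to see what the new branch does step by step, here is a minimal plain-PyTorch sketch of the same conversion, assuming 2D labels of shape (N, 1, H, W); the variable names are illustrative, not part of the library.

```python
import torch
from torch import nn

num_classes = 3
labels = torch.randint(0, num_classes, (2, 1, 4, 4))  # single-channel integer IDs

# Mirrors the branch added to forward(): single-channel ID labels are detected,
# the spatial dimensionality is inferred, and the IDs are expanded to a one-hot,
# channels-first float tensor.
if num_classes != 1 and labels.shape[1] == 1:
    d = labels.ndim - 2                                        # 2 spatial dims here
    ids = labels.int().squeeze(1)                              # (N, H, W)
    one_hot = nn.functional.one_hot(ids.long(), num_classes)   # (N, H, W, C)
    labels = one_hot.movedim(-1, 1).contiguous().float()       # (N, C, H, W)

print(labels.shape)  # torch.Size([2, 3, 4, 4])
```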

mipcandy/data/convertion.py (1 addition, 1 deletion)

```diff
@@ -15,7 +15,7 @@ def convert_ids_to_logits(ids: torch.Tensor, d: Literal[1, 2, 3], num_classes: int
         ids = ids.squeeze(1)
     else:
         raise ValueError(f"`ids` should be {d} dimensional or {d + 1} dimensional with single channel")
-    return nn.functional.one_hot(ids.long(), num_classes).movedim(-1, 1).float()
+    return nn.functional.one_hot(ids.long(), num_classes).movedim(-1, 1).contiguous().float()
 
 
 def convert_logits_to_ids(logits: torch.Tensor, *, channel_dim: int = 1) -> torch.Tensor:
```
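
Why the added `.contiguous()`? `torch.Tensor.movedim` returns a strided view, so the one-hot tensor leaving this function would otherwise have permuted strides; downstream operations that require contiguous memory can then fail or trigger implicit copies. A small self-contained demonstration (shapes are illustrative):

```python
import torch
from torch import nn

ids = torch.tensor([[0, 1], [2, 0]])   # (N, W) integer class IDs
t = nn.functional.one_hot(ids, 3)      # (N, W, C), contiguous
v = t.movedim(-1, 1)                   # (N, C, W), a non-copying view

print(v.is_contiguous())               # False: strides are permuted
print(v.contiguous().is_contiguous())  # True: .contiguous() materializes the layout
```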
