Skip to content

Commit 3ac3ffb

Browse files
authored
soft_dice_coefficient incorrectly flattens all classes together (#125)
* compute soft dice per-class instead of flattening all classes * add do_bg parameter to DiceBCELossWithLogits * rename do_bg to include_bg * add include_bg argument in SegmentationTrainer
1 parent 47e0867 commit 3ac3ffb

File tree

3 files changed

+13
-10
lines changed

3 files changed

+13
-10
lines changed

mipcandy/common/optim/loss.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@ def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
2626

2727
class DiceBCELossWithLogits(nn.Module):
2828
def __init__(self, num_classes: int, *, lambda_bce: float = .5, lambda_soft_dice: float = 1,
29-
smooth: float = 1e-5) -> None:
29+
smooth: float = 1e-5, include_bg: bool = True) -> None:
3030
super().__init__()
3131
self.num_classes: int = num_classes
3232
self.lambda_bce: float = lambda_bce
3333
self.lambda_soft_dice: float = lambda_soft_dice
3434
self.smooth: float = smooth
35+
self.include_bg: bool = include_bg
3536

3637
def forward(self, masks: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
3738
if self.num_classes != 1 and labels.shape[1] == 1:
@@ -42,6 +43,6 @@ def forward(self, masks: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tens
4243
labels = labels.float()
4344
bce = nn.functional.binary_cross_entropy_with_logits(masks, labels)
4445
masks = masks.sigmoid()
45-
soft_dice = soft_dice_coefficient(masks, labels, smooth=self.smooth)
46+
soft_dice = soft_dice_coefficient(masks, labels, smooth=self.smooth, include_bg=self.include_bg)
4647
c = self.lambda_bce * bce + self.lambda_soft_dice * (1 - soft_dice)
4748
return c, {"soft dice": soft_dice.item(), "bce loss": bce.item()}

mipcandy/metrics.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,15 @@ def dice_similarity_coefficient_multiclass(output: torch.Tensor, label: torch.Te
6060
return apply_multiclass_to_binary(dice_similarity_coefficient_binary, output, label, num_classes, if_empty)
6161

6262

63-
def soft_dice_coefficient(output: torch.Tensor, label: torch.Tensor, *, smooth: float = 1e-5) -> torch.Tensor:
63+
def soft_dice_coefficient(output: torch.Tensor, label: torch.Tensor, *,
64+
smooth: float = 1e-5, include_bg: bool = True) -> torch.Tensor:
6465
_args_check(output, label)
65-
num = label.size(0)
66-
output = output.view(num, -1)
67-
label = label.view(num, -1)
68-
intersection = (output * label)
69-
dice = (2 * intersection.sum(1) + smooth) / (output.sum(1) + label.sum(1) + smooth)
70-
return dice.sum() / num
66+
axes = tuple(range(2, output.ndim))
67+
intersection = (output * label).sum(dim=axes)
68+
dice = (2 * intersection + smooth) / (output.sum(dim=axes) + label.sum(dim=axes) + smooth)
69+
if not include_bg:
70+
dice = dice[:, 1:]
71+
return dice.mean()
7172

7273

7374
def accuracy_binary(output: torch.Tensor, label: torch.Tensor, *, if_empty: float = 1) -> torch.Tensor:

mipcandy/presets/segmentation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
class SegmentationTrainer(Trainer, metaclass=ABCMeta):
1515
num_classes: int = 1
16+
include_bg: bool = True
1617

1718
def _save_preview(self, x: torch.Tensor, title: str, quality: float) -> None:
1819
path = f"{self.experiment_folder()}/{title} (preview).png"
@@ -36,7 +37,7 @@ def save_preview(self, image: torch.Tensor, label: torch.Tensor, output: torch.T
3637

3738
@override
3839
def build_criterion(self) -> nn.Module:
39-
return DiceBCELossWithLogits(self.num_classes)
40+
return DiceBCELossWithLogits(self.num_classes, include_bg=self.include_bg)
4041

4142
@override
4243
def build_optimizer(self, params: Params) -> optim.Optimizer:

0 commit comments

Comments
 (0)