
Commit 53dabb5

Change Dice metric include_background for WNet
To avoid the per-channel max-Dice calculation during validation
1 parent d3414e8 commit 53dabb5
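
In practice this means validation no longer has to pick the single best output channel before scoring. A minimal sketch of the difference, with hypothetical shapes; the local dice_coeff helper, the one_hot call on the labels, and the random tensors are illustrative stand-ins, not the plugin's exact code (which uses utils.dice_coeff):

import torch
from monai.metrics import DiceMetric
from monai.networks.utils import one_hot

# Hypothetical binarized 2-channel prediction and single-channel foreground labels.
logits = torch.randn(1, 2, 8, 8, 8)
val_outputs = one_hot(logits.argmax(dim=1, keepdim=True), num_classes=2)
val_labels = torch.randint(0, 2, (1, 1, 8, 8, 8)).float()

def dice_coeff(y_pred, y_true, eps=1e-8):
    # Plain single-channel Dice, standing in for utils.dice_coeff.
    inter = (y_pred * y_true).sum()
    return (2.0 * inter / (y_pred.sum() + y_true.sum() + eps)).item()

# Before: score each channel, keep the best one ("Max Dice"), then evaluate only
# that channel with include_background=False.
dices = [
    dice_coeff(val_outputs[:, c : c + 1], val_labels)
    for c in range(val_outputs.shape[1])
]
max_dice_channel = int(torch.argmax(torch.tensor(dices)))

# After: include the background channel and score every channel directly.
dice_metric = DiceMetric(
    include_background=True, reduction="mean", get_not_nans=False
)
dice_metric(y_pred=val_outputs, y=one_hot(val_labels, num_classes=2))
print(max_dice_channel, dice_metric.aggregate().item())
dice_metric.reset()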

File tree

2 files changed: +50 -35 lines changed


napari_cellseg3d/code_models/models/wnet/model.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def __init__(
         )

     def forward(self, x):
-        """Forward pass of the W-Net model."""
+        """Forward pass of the W-Net model. Returns the segmentation and the reconstructed image."""
         enc = self.forward_encoder(x)
         return enc, self.forward_decoder(enc)

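The training worker (next file) now relies on this single forward call instead of invoking forward_encoder and forward_decoder separately. A minimal sketch of the call pattern, assuming `model` is a WNet instance and `criterionE` / `criterionW` are the N-cuts and reconstruction losses used by the worker; the helper name `wnet_training_step` and the simplified loss wiring are illustrative, not the plugin's code:

import torch
from torch import nn

def wnet_training_step(
    model: nn.Module,
    image_batch: torch.Tensor,
    criterionE: nn.Module,
    criterionW: nn.Module,
):
    # One call returns both outputs: the encoder's segmentation and the
    # decoder's reconstruction (previously forward_encoder + forward_decoder).
    enc, dec = model(image_batch)
    ncuts_loss = criterionE(enc, image_batch)  # soft N-cuts on the segmentation
    reconstruction_loss = criterionW(dec, image_batch)  # e.g. nn.MSELoss
    return ncuts_loss, reconstruction_loss
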
napari_cellseg3d/code_models/worker_training.py

Lines changed: 49 additions & 34 deletions
@@ -287,10 +287,10 @@ def _get_data(self):
         train_transforms = EnsureTyped(keys=["image"])

         if self.config.sampling:
-            self.log("Loading patch dataset")
+            logger.debug("Loading patch dataset")
             (data_shape, dataset) = self.get_patch_dataset(train_transforms)
         else:
-            self.log("Loading volume dataset")
+            logger.debug("Loading volume dataset")
             (data_shape, dataset) = self.get_dataset(train_transforms)

         logger.debug(f"Data shape : {data_shape}")
@@ -388,7 +388,7 @@ def train(self):
         dataloader, eval_dataloader, data_shape = self._get_data()

         dice_metric = DiceMetric(
-            include_background=False, reduction="mean", get_not_nans=False
+            include_background=True, reduction="mean", get_not_nans=False
         )
         ###################################################
         # Training the model #
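
DiceMetric is cumulative: the worker feeds it one validation batch at a time and reads the epoch value with aggregate() before reset(). A self-contained sketch of that lifecycle; the random one-hot tensors and loop length are stand-ins for the real validation outputs and labels:

import torch
from monai.metrics import DiceMetric

dice_metric = DiceMetric(
    include_background=True, reduction="mean", get_not_nans=False
)

for _ in range(3):  # stand-in for the validation loop
    labels = torch.randint(0, 2, (1, 1, 8, 8, 8)).float()
    y = torch.cat([1 - labels, labels], dim=1)  # (B, 2, D, H, W) one-hot labels
    y_pred = y.clone()                          # perfect prediction, for the demo
    dice_metric(y_pred=y_pred, y=y)             # accumulates per-batch Dice scores

metric = dice_metric.aggregate().item()  # mean Dice over the accumulated batches
dice_metric.reset()                      # clear the buffer before the next epoch
print(metric)                            # 1.0 for the perfect toy predictions
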
@@ -510,15 +510,13 @@ def train(self):
                 )

                 # Forward pass
-                enc = model.forward_encoder(image_batch)
+                enc, dec = model(image_batch)
                 # Compute the Ncuts loss
                 Ncuts = criterionE(enc, image_batch)
                 epoch_ncuts_loss += Ncuts.item()
                 # if WANDB_INSTALLED:
                 #     wandb.log({"Ncuts loss": Ncuts.item()})

-                dec = model.forward_decoder(enc)
-
                 # Compute the reconstruction loss
                 if isinstance(criterionW, nn.MSELoss):
                     reconstruction_loss = criterionW(dec, image_batch)
@@ -685,32 +683,33 @@ def train(self):
                         f"Val decoder outputs shape: {val_decoder_outputs.shape}"
                     )

-                    dices = []
+                    # dices = []
                     # Find in which channel the labels are (avoid background)
-                    for channel in range(val_outputs.shape[1]):
-                        dices.append(
-                            utils.dice_coeff(
-                                y_pred=val_outputs[
-                                    0, channel : (channel + 1), :, :, :
-                                ],
-                                y_true=val_labels[0],
-                            )
-                        )
-                    logger.debug(f"DICE COEFF: {dices}")
-                    max_dice_channel = torch.argmax(
-                        torch.Tensor(dices)
-                    )
-                    logger.debug(
-                        f"MAX DICE CHANNEL: {max_dice_channel}"
-                    )
+                    # for channel in range(val_outputs.shape[1]):
+                    #     dices.append(
+                    #         utils.dice_coeff(
+                    #             y_pred=val_outputs[
+                    #                 0, channel : (channel + 1), :, :, :
+                    #             ],
+                    #             y_true=val_labels[0],
+                    #         )
+                    #     )
+                    # logger.debug(f"DICE COEFF: {dices}")
+                    # max_dice_channel = torch.argmax(
+                    #     torch.Tensor(dices)
+                    # )
+                    # logger.debug(
+                    #     f"MAX DICE CHANNEL: {max_dice_channel}"
+                    # )
                     dice_metric(
-                        y_pred=val_outputs[
-                            :,
-                            max_dice_channel : (max_dice_channel + 1),
-                            :,
-                            :,
-                            :,
-                        ],
+                        y_pred=val_outputs,
+                        # [
+                        #     :,
+                        #     max_dice_channel : (max_dice_channel + 1),
+                        #     :,
+                        #     :,
+                        #     :,
+                        # ],
                         y=val_labels,
                     )

@@ -736,11 +735,19 @@ def train(self):
                     # wandb.log({"val/dice_metric": metric})

                     dec_out_val = (
-                        val_decoder_outputs[0].detach().cpu().numpy()
+                        val_decoder_outputs[0]
+                        .detach()
+                        .cpu()
+                        .numpy()
+                        .copy()
+                    )
+                    enc_out_val = (
+                        val_outputs[0].detach().cpu().numpy().copy()
+                    )
+                    lab_out_val = (
+                        val_labels[0].detach().cpu().numpy().copy()
                     )
-                    enc_out_val = val_outputs[0].detach().cpu().numpy()
-                    lab_out_val = val_labels[0].detach().cpu().numpy()
-                    val_in = val_inputs[0].detach().cpu().numpy()
+                    val_in = val_inputs[0].detach().cpu().numpy().copy()

                     display_dict = {
                         "Reconstruction": {
@@ -760,6 +767,14 @@ def train(self):
                             "cmap": "bop blue",
                         },
                     }
+                    val_decoder_outputs = None
+                    del val_decoder_outputs
+                    val_outputs = None
+                    del val_outputs
+                    val_labels = None
+                    del val_labels
+                    val_inputs = None
+                    del val_inputs

                     yield TrainingReport(
                         epoch=epoch,
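
The added block drops the validation tensors once their NumPy copies exist, so the references do not survive until the next report; together with the .copy() calls above, nothing handed to the viewer keeps the torch tensors alive. A small self-contained sketch of the pattern; the toy CPU tensors stand in for the real, possibly CUDA, validation outputs:

import numpy as np
import torch

# Toy stand-ins for val_decoder_outputs / val_outputs.
val_decoder_outputs = torch.rand(1, 1, 8, 8, 8)
val_outputs = torch.rand(1, 2, 8, 8, 8)

# .copy() yields NumPy arrays that no longer share memory with the intermediate
# CPU tensors, so nothing downstream keeps those tensors alive.
dec_out_val = val_decoder_outputs[0].detach().cpu().numpy().copy()
enc_out_val = val_outputs[0].detach().cpu().numpy().copy()

# Clearing and deleting the names drops the last references to the validation
# tensors, letting their (GPU) memory be reclaimed before the next epoch.
val_decoder_outputs = None
del val_decoder_outputs
val_outputs = None
del val_outputs

assert isinstance(dec_out_val, np.ndarray)
assert isinstance(enc_out_val, np.ndarray)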
