@@ -287,10 +287,10 @@ def _get_data(self):
287287 train_transforms = EnsureTyped (keys = ["image" ])
288288
289289 if self .config .sampling :
290- self . log ("Loading patch dataset" )
290+ logger . debug ("Loading patch dataset" )
291291 (data_shape , dataset ) = self .get_patch_dataset (train_transforms )
292292 else :
293- self . log ("Loading volume dataset" )
293+ logger . debug ("Loading volume dataset" )
294294 (data_shape , dataset ) = self .get_dataset (train_transforms )
295295
296296 logger .debug (f"Data shape : { data_shape } " )
@@ -388,7 +388,7 @@ def train(self):
388388 dataloader , eval_dataloader , data_shape = self ._get_data ()
389389
390390 dice_metric = DiceMetric (
391- include_background = False , reduction = "mean" , get_not_nans = False
391+ include_background = True , reduction = "mean" , get_not_nans = False
392392 )
393393 ###################################################
394394 # Training the model #
@@ -510,15 +510,13 @@ def train(self):
510510 )
511511
512512 # Forward pass
513- enc = model . forward_encoder (image_batch )
513+ enc , dec = model (image_batch )
514514 # Compute the Ncuts loss
515515 Ncuts = criterionE (enc , image_batch )
516516 epoch_ncuts_loss += Ncuts .item ()
517517 # if WANDB_INSTALLED:
518518 # wandb.log({"Ncuts loss": Ncuts.item()})
519519
520- dec = model .forward_decoder (enc )
521-
522520 # Compute the reconstruction loss
523521 if isinstance (criterionW , nn .MSELoss ):
524522 reconstruction_loss = criterionW (dec , image_batch )
@@ -685,32 +683,33 @@ def train(self):
685683 f"Val decoder outputs shape: { val_decoder_outputs .shape } "
686684 )
687685
688- dices = []
686+ # dices = []
689687 # Find in which channel the labels are (avoid background)
690- for channel in range (val_outputs .shape [1 ]):
691- dices .append (
692- utils .dice_coeff (
693- y_pred = val_outputs [
694- 0 , channel : (channel + 1 ), :, :, :
695- ],
696- y_true = val_labels [0 ],
697- )
698- )
699- logger .debug (f"DICE COEFF: { dices } " )
700- max_dice_channel = torch .argmax (
701- torch .Tensor (dices )
702- )
703- logger .debug (
704- f"MAX DICE CHANNEL: { max_dice_channel } "
705- )
688+ # for channel in range(val_outputs.shape[1]):
689+ # dices.append(
690+ # utils.dice_coeff(
691+ # y_pred=val_outputs[
692+ # 0, channel : (channel + 1), :, :, :
693+ # ],
694+ # y_true=val_labels[0],
695+ # )
696+ # )
697+ # logger.debug(f"DICE COEFF: {dices}")
698+ # max_dice_channel = torch.argmax(
699+ # torch.Tensor(dices)
700+ # )
701+ # logger.debug(
702+ # f"MAX DICE CHANNEL: {max_dice_channel}"
703+ # )
706704 dice_metric (
707- y_pred = val_outputs [
708- :,
709- max_dice_channel : (max_dice_channel + 1 ),
710- :,
711- :,
712- :,
713- ],
705+ y_pred = val_outputs ,
706+ # [
707+ # :,
708+ # max_dice_channel : (max_dice_channel + 1),
709+ # :,
710+ # :,
711+ # :,
712+ # ],
714713 y = val_labels ,
715714 )
716715
@@ -736,11 +735,19 @@ def train(self):
736735 # wandb.log({"val/dice_metric": metric})
737736
738737 dec_out_val = (
739- val_decoder_outputs [0 ].detach ().cpu ().numpy ()
738+ val_decoder_outputs [0 ]
739+ .detach ()
740+ .cpu ()
741+ .numpy ()
742+ .copy ()
743+ )
744+ enc_out_val = (
745+ val_outputs [0 ].detach ().cpu ().numpy ().copy ()
746+ )
747+ lab_out_val = (
748+ val_labels [0 ].detach ().cpu ().numpy ().copy ()
740749 )
741- enc_out_val = val_outputs [0 ].detach ().cpu ().numpy ()
742- lab_out_val = val_labels [0 ].detach ().cpu ().numpy ()
743- val_in = val_inputs [0 ].detach ().cpu ().numpy ()
750+ val_in = val_inputs [0 ].detach ().cpu ().numpy ().copy ()
744751
745752 display_dict = {
746753 "Reconstruction" : {
@@ -760,6 +767,14 @@ def train(self):
760767 "cmap" : "bop blue" ,
761768 },
762769 }
770+ val_decoder_outputs = None
771+ del val_decoder_outputs
772+ val_outputs = None
773+ del val_outputs
774+ val_labels = None
775+ del val_labels
776+ val_inputs = None
777+ del val_inputs
763778
764779 yield TrainingReport (
765780 epoch = epoch ,
0 commit comments