@@ -282,9 +282,7 @@ def train(self):
282282 size = self .config .sample_size if do_sampling else check
283283
284284 PADDING = utils .get_padding_dim (size )
285- model = model_class ( # FIXME check if correct
286- input_img_size = PADDING , use_checkpoint = True
287- )
285+ model = model_class (input_img_size = PADDING , use_checkpoint = True )
288286 model = model .to (self .config .device )
289287
290288 epoch_loss_values = []
@@ -570,7 +568,7 @@ def get_loader_func(num_samples):
570568 if outputs .shape [1 ] > 1 :
571569 outputs = outputs [
572570 :, 1 :, :, :
573- ] # FIXME fix channel number
571+ ] # TODO(cyril): adapt if additional channels
574572 if len (outputs .shape ) < 4 :
575573 outputs = outputs .unsqueeze (0 )
576574 loss = self .config .loss_function (outputs , labels )
@@ -632,17 +630,15 @@ def get_loader_func(num_samples):
632630 )
633631 except Exception as e :
634632 self .raise_error (e , "Error during validation" )
633+
635634 logger .debug (
636635 f"val_outputs shape : { val_outputs .shape } "
637636 )
638637 # val_outputs = model(val_inputs)
639638
640639 pred = decollate_batch (val_outputs )
641-
642640 labs = decollate_batch (val_labels )
643-
644641 # TODO : more parameters/flexibility
645-
646642 post_pred = Compose (
647643 [
648644 RemapTensor (new_max = 1 , new_min = 0 ),
@@ -668,15 +664,15 @@ def get_loader_func(num_samples):
668664
669665 # logger.debug(len(val_outputs))
670666 # logger.debug(len(val_labels))
671- dice_test = np .array ( # TODO(cyril): remove
672- [
673- utils .dice_coeff (i , j )
674- for i , j in zip (val_outputs , val_labels )
675- ]
676- )
677- logger .debug (
678- f"TEST VALIDATION Dice score : { dice_test .mean ()} "
679- )
667+ # dice_test = np.array(
668+ # [
669+ # utils.dice_coeff(i, j)
670+ # for i, j in zip(val_outputs, val_labels)
671+ # ]
672+ # )
673+ # logger.debug(
674+ # f"TEST VALIDATION Dice score : {dice_test.mean()}"
675+ # )
680676
681677 dice_metric (y_pred = val_outputs , y = val_labels )
682678
0 commit comments