6868logger .debug (f"PRETRAINED WEIGHT DIR LOCATION : { PRETRAINED_WEIGHTS_DIR } " )
6969
7070try :
71- import wandb
72-
7371 WANDB_INSTALLED = True
7472except ImportError :
7573 logger .warning (
@@ -411,25 +409,25 @@ def log_parameters(self):
411409 self .log (f"Using { self .config .num_classes } classes" )
412410 self .log (f"Weight decay: { self .config .weight_decay } " )
413411 self .log ("* NCuts : " )
414- self .log (f"- Insensity sigma { self .config .intensity_sigma } " )
412+ self .log (f"- Intensity sigma { self .config .intensity_sigma } " )
415413 self .log (f"- Spatial sigma { self .config .spatial_sigma } " )
416414 self .log (f"- Radius : { self .config .radius } " )
417415 self .log (f"* Reconstruction loss : { self .config .reconstruction_loss } " )
418416 self .log (
419- f"Weighted sum : { self .config .n_cuts_weight } *Ncuts + { self .config .rec_loss_weight } *Reconstruction"
417+ f"Weighted sum : { self .config .n_cuts_weight } *NCuts + { self .config .rec_loss_weight } *Reconstruction"
420418 )
421419 ##############
422420 self .log ("-- Data --" )
423421 self .log ("Training data :" )
424422 [
425- self .log (f"\n { v } " )
423+ self .log (f"{ v } " )
426424 for d in self .config .train_data_dict
427425 for k , v in d .items ()
428426 ]
429427 if self .config .eval_volume_dict is not None :
430428 self .log ("Validation data :" )
431429 [
432- self .log (f"\n { k } : { v } " )
430+ self .log (f"{ k } : { v } " )
433431 for d in self .config .eval_volume_dict
434432 for k , v in d .items ()
435433 ]
@@ -443,9 +441,9 @@ def train(self):
443441 set_track_meta (False )
444442 ##############
445443 # if WANDB_INSTALLED:
446- # wandb.init(
447- # config=WANDB_CONFIG, project="WNet-benchmark", mode=WANDB_MODE
448- # )
444+ # wandb.init(
445+ # config=WANDB_CONFIG, project="WNet-benchmark", mode=WANDB_MODE
446+ # )
449447
450448 set_determinism (
451449 seed = self .config .deterministic_config .seed
@@ -455,12 +453,8 @@ def train(self):
455453 normalize_function = utils .remap_image
456454 device = self .config .device
457455
458- self .log (f"Using device: { device } " )
459-
460- # self.log("Config:") # FIXME log_parameters func instead
461- # [self.log(str(a)) for a in self.config.__dict__.items()]
456+ # self.log(f"Using device: {device}")
462457 self .log_parameters ()
463-
464458 self .log ("Initializing training..." )
465459 self .log ("Getting the data" )
466460
@@ -473,7 +467,6 @@ def train(self):
473467 # Training the model #
474468 ###################################################
475469 self .log ("Initializing the model:" )
476-
477470 self .log ("- Getting the model" )
478471 # Initialize the model
479472 model = WNet (
@@ -494,8 +487,8 @@ def train(self):
494487 )
495488 )
496489
497- if WANDB_INSTALLED :
498- wandb .watch (model , log_freq = 100 )
490+ # if WANDB_INSTALLED:
491+ # wandb.watch(model, log_freq=100)
499492
500493 if self .config .weights_info .custom :
501494 if self .config .weights_info .use_pretrained :
@@ -619,10 +612,10 @@ def train(self):
619612 )
620613
621614 epoch_rec_loss += reconstruction_loss .item ()
622- if WANDB_INSTALLED :
623- wandb .log (
624- {"Reconstruction loss" : reconstruction_loss .item ()}
625- )
615+ # if WANDB_INSTALLED:
616+ # wandb.log(
617+ # {"Reconstruction loss": reconstruction_loss.item()}
618+ # )
626619
627620 # Backward pass for the reconstruction loss
628621 optimizer .zero_grad ()
@@ -631,8 +624,8 @@ def train(self):
631624
632625 loss = alpha * Ncuts + beta * reconstruction_loss
633626 epoch_loss += loss .item ()
634- if WANDB_INSTALLED :
635- wandb .log ({"Weighted sum of losses" : loss .item ()})
627+ # if WANDB_INSTALLED:
628+ # wandb.log({"Weighted sum of losses": loss.item()})
636629 loss .backward (loss )
637630 optimizer .step ()
638631
@@ -818,9 +811,9 @@ def train(self):
818811 self .log (f"Saving new best model to { save_path } " )
819812 torch .save (model .state_dict (), save_path )
820813
821- if WANDB_INSTALLED :
822- # log validation dice score for each validation round
823- wandb .log ({"val/dice_metric" : metric })
814+ # if WANDB_INSTALLED:
815+ # log validation dice score for each validation round
816+ # wandb.log({"val/dice_metric": metric})
824817
825818 dec_out_val = (
826819 val_decoder_outputs [0 ].detach ().cpu ().numpy ()
0 commit comments