@@ -388,6 +388,48 @@ def _get_data(self):
         eval_dataloader = None
         return dataloader, eval_dataloader, data_shape
 
+    def log_parameters(self):
+        self.log("*" * 20)
+        self.log("-- Parameters --")
+        self.log(f"Device: {self.config.device}")
+        self.log(f"Batch size: {self.config.batch_size}")
+        self.log(f"Epochs: {self.config.max_epochs}")
+        self.log(f"Learning rate: {self.config.learning_rate}")
+        self.log(f"Validation interval: {self.config.validation_interval}")
+        if self.config.weights_info.custom:
+            self.log(f"Custom weights: {self.config.weights_info.path}")
+        elif self.config.weights_info.use_pretrained:
+            self.log(f"Pretrained weights: {self.config.weights_info.path}")
+        if self.config.sampling:
+            self.log(
+                f"Using {self.config.num_samples} samples of size {self.config.sample_size}"
+            )
+        if self.config.do_augmentation:
+            self.log("Using data augmentation")
+        ##############
+        self.log("-- Model --")
+        self.log(f"Using {self.config.num_classes} classes")
+        self.log(f"Weight decay: {self.config.weight_decay}")
+        self.log("* NCuts:")
+        self.log(f"- Intensity sigma: {self.config.intensity_sigma}")
+        self.log(f"- Spatial sigma: {self.config.spatial_sigma}")
+        self.log(f"- Radius: {self.config.radius}")
+        self.log(f"* Reconstruction loss: {self.config.reconstruction_loss}")
+        self.log(
+            f"Weighted sum: {self.config.n_cuts_weight}*NCuts + {self.config.rec_loss_weight}*Reconstruction"
+        )
+        ##############
+        self.log("-- Data --")
+        self.log("Training data:")
+        [self.log(f"\n{v}") for v in self.config.train_data_dict.values()]
+        if self.config.eval_volume_dict is not None:
+            self.log("Validation data:")
+            [
+                self.log(f"\n{k}: {v}")
+                for d in self.config.eval_volume_dict
+                for k, v in d.items()
+            ]
+
     def train(self):
         try:
             if self.config is None:
@@ -411,8 +453,9 @@ def train(self):
 
             self.log(f"Using device: {device}")
 
-            self.log("Config:")  # FIXME log_parameters func instead
-            [self.log(str(a)) for a in self.config.__dict__.items()]
+            # self.log("Config:")  # FIXME log_parameters func instead
+            # [self.log(str(a)) for a in self.config.__dict__.items()]
+            self.log_parameters()
 
             self.log("Initializing training...")
             self.log("Getting the data")
@@ -783,11 +826,11 @@ def train(self):
                 val_in = val_inputs[0].detach().cpu().numpy()
 
                 display_dict = {
-                    "Decoder output": {
+                    "Reconstruction": {
                         "data": np.squeeze(dec_out_val),
                         "cmap": "gist_earth",
                     },
-                    "Encoder output": {
+                    "Segmentation": {
                         "data": np.squeeze(enc_out_val),
                         "cmap": "turbo",
                     },
@@ -820,7 +863,7 @@ def train(self):
                     * (self.config.max_epochs / (epoch + 1) - 1)
                     / 60
                 )
-                self.log(f"ETA: {eta:.2f} minutes")
+                self.log(f"ETA: {eta:.1f} minutes")
                 self.log("-" * 20)
 
                 # Save the model
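
A side note on the ETA line touched in the last hunk: the visible tail of the expression multiplies a leading factor cut off by the hunk by `(max_epochs / (epoch + 1) - 1) / 60`, i.e. the ratio of remaining to completed epochs, converted from seconds to minutes. Below is a self-contained sketch of that arithmetic, assuming the truncated leading factor is the elapsed time since training started; the function and variable names are illustrative, not the worker's own.

```python
import time


def eta_minutes(start_time: float, epoch: int, max_epochs: int) -> float:
    """Estimate remaining training time from the average cost of completed epochs.

    Assumes `epoch` is zero-based, as in the loop above, so `epoch + 1` epochs
    have completed when the estimate is made.
    """
    elapsed = time.time() - start_time  # seconds spent so far (assumed leading factor)
    remaining_per_completed = max_epochs / (epoch + 1) - 1  # remaining / completed epochs
    return elapsed * remaining_per_completed / 60  # scale elapsed time, convert to minutes


# Example: 10 epochs planned, 4 completed in ~120 s -> 120 * (10/4 - 1) / 60 = 3.0 minutes left.
print(f"ETA: {eta_minutes(time.time() - 120, 3, 10):.1f} minutes")
```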