Skip to content

Commit 1eed4ea

Browse files
committed
Plot + log_parameters
1 parent a6964ab commit 1eed4ea

File tree

2 files changed

+53
-7
lines changed

2 files changed

+53
-7
lines changed

napari_cellseg3d/code_models/worker_training.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,48 @@ def _get_data(self):
388388
eval_dataloader = None
389389
return dataloader, eval_dataloader, data_shape
390390

391+
def log_parameters(self):
392+
self.log("*" * 20)
393+
self.log("-- Parameters --")
394+
self.log(f"Device: {self.config.device}")
395+
self.log(f"Batch size: {self.config.batch_size}")
396+
self.log(f"Epochs: {self.config.max_epochs}")
397+
self.log(f"Learning rate: {self.config.learning_rate}")
398+
self.log(f"Validation interval: {self.config.validation_interval}")
399+
if self.config.weights_info.custom:
400+
self.log(f"Custom weights: {self.config.weights_info.path}")
401+
elif self.config.weights_info.use_pretrained:
402+
self.log(f"Pretrained weights: {self.config.weights_info.path}")
403+
if self.config.sampling:
404+
self.log(
405+
f"Using {self.config.num_samples} samples of size {self.config.sample_size}"
406+
)
407+
if self.config.do_augmentation:
408+
self.log("Using data augmentation")
409+
##############
410+
self.log("-- Model --")
411+
self.log(f"Using {self.config.num_classes} classes")
412+
self.log(f"Weight decay: {self.config.weight_decay}")
413+
self.log("* NCuts : ")
414+
self.log(f"- Insensity sigma {self.config.intensity_sigma}")
415+
self.log(f"- Spatial sigma {self.config.spatial_sigma}")
416+
self.log(f"- Radius : {self.config.radius}")
417+
self.log(f"* Reconstruction loss : {self.config.reconstruction_loss}")
418+
self.log(
419+
f"Weighted sum : {self.config.n_cuts_weight}*Ncuts + {self.config.rec_loss_weight}*Reconstruction"
420+
)
421+
##############
422+
self.log("-- Data --")
423+
self.log("Training data :")
424+
[self.log(f"\n{v}") for k, v in self.config.train_data_dict.items()]
425+
if self.config.eval_volume_dict is not None:
426+
self.log("Validation data :")
427+
[
428+
self.log(f"\n{k}: {v}")
429+
for d in self.config.eval_volume_dict
430+
for k, v in d.items()
431+
]
432+
391433
def train(self):
392434
try:
393435
if self.config is None:
@@ -411,8 +453,9 @@ def train(self):
411453

412454
self.log(f"Using device: {device}")
413455

414-
self.log("Config:") # FIXME log_parameters func instead
415-
[self.log(str(a)) for a in self.config.__dict__.items()]
456+
# self.log("Config:") # FIXME log_parameters func instead
457+
# [self.log(str(a)) for a in self.config.__dict__.items()]
458+
self.log_parameters()
416459

417460
self.log("Initializing training...")
418461
self.log("Getting the data")
@@ -783,11 +826,11 @@ def train(self):
783826
val_in = val_inputs[0].detach().cpu().numpy()
784827

785828
display_dict = {
786-
"Decoder output": {
829+
"Reconstruction": {
787830
"data": np.squeeze(dec_out_val),
788831
"cmap": "gist_earth",
789832
},
790-
"Encoder output": {
833+
"Segmentation": {
791834
"data": np.squeeze(enc_out_val),
792835
"cmap": "turbo",
793836
},
@@ -820,7 +863,7 @@ def train(self):
820863
* (self.config.max_epochs / (epoch + 1) - 1)
821864
/ 60
822865
)
823-
self.log(f"ETA: {eta:.2f} minutes")
866+
self.log(f"ETA: {eta:.1f} minutes")
824867
self.log("-" * 20)
825868

826869
# Save the model

napari_cellseg3d/code_plugins/plugin_model_training.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,7 +1036,8 @@ def _set_worker_config(
10361036
self.weights_config.path = self.weights_config.path
10371037
self.weights_config.custom = self.custom_weights_choice.isChecked()
10381038
self.weights_config.use_pretrained = (
1039-
not self.use_transfer_choice.isChecked()
1039+
self.use_transfer_choice.isChecked()
1040+
and not self.custom_weights_choice.isChecked()
10401041
)
10411042
deterministic_config = config.DeterministicConfig(
10421043
enabled=self.use_deterministic_choice.isChecked(),
@@ -1436,7 +1437,9 @@ def _plot_loss(
14361437
if metric_name == "Dice metric":
14371438
self._show_plot_max(self.plot_1, y)
14381439
if len(loss_values_1.keys()) > 1:
1439-
self.plot_1.legend(loc="best", fontsize="10", markerscale=0.6)
1440+
self.plot_1.legend(
1441+
loc="lower left", fontsize="10", markerscale=0.6
1442+
)
14401443

14411444
# update plot 2
14421445
if self._is_current_job_supervised():

0 commit comments

Comments
 (0)