
Commit 7f3a118 (parent: 79724dd)

Disable WANDB for now + log param tweaks

2 files changed (+20, -27 lines)

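The main change is that the optional wandb integration in the training worker is commented out. For reference, the pattern being disabled is the usual optional-dependency guard: the import sits in a try/except ImportError block, a WANDB_INSTALLED flag records whether it succeeded, and every wandb call is gated on that flag. Below is a minimal, self-contained sketch of that pattern; the WANDB_CONFIG/WANDB_MODE values and the dummy metric are placeholders, not the plugin's actual configuration.

import logging

logger = logging.getLogger(__name__)

# Optional-dependency guard: wandb may not be installed next to the plugin.
try:
    import wandb

    WANDB_INSTALLED = True
except ImportError:
    WANDB_INSTALLED = False
    logger.warning("wandb is not installed, experiment tracking is disabled")

# Placeholder values standing in for the worker's WANDB_CONFIG / WANDB_MODE.
WANDB_CONFIG = {"model": "WNet", "num_epochs": 50}
WANDB_MODE = "offline"  # "online", "offline" or "disabled"

if WANDB_INSTALLED:
    wandb.init(config=WANDB_CONFIG, project="WNet-benchmark", mode=WANDB_MODE)

for step in range(3):
    dummy_loss = 1.0 / (step + 1)  # placeholder metric
    if WANDB_INSTALLED:
        wandb.log({"Weighted sum of losses": dummy_loss})

The script runs whether or not wandb is present, which is the whole point of the guard; the commit simply comments out the gated calls instead of removing the guard.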

napari_cellseg3d/code_models/worker_training.py

Lines changed: 19 additions & 26 deletions
@@ -68,8 +68,6 @@
 logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {PRETRAINED_WEIGHTS_DIR}")
 
 try:
-    import wandb
-
     WANDB_INSTALLED = True
 except ImportError:
     logger.warning(
@@ -411,25 +409,25 @@ def log_parameters(self):
         self.log(f"Using {self.config.num_classes} classes")
         self.log(f"Weight decay: {self.config.weight_decay}")
         self.log("* NCuts : ")
-        self.log(f"- Insensity sigma {self.config.intensity_sigma}")
+        self.log(f"- Intensity sigma {self.config.intensity_sigma}")
         self.log(f"- Spatial sigma {self.config.spatial_sigma}")
         self.log(f"- Radius : {self.config.radius}")
         self.log(f"* Reconstruction loss : {self.config.reconstruction_loss}")
         self.log(
-            f"Weighted sum : {self.config.n_cuts_weight}*Ncuts + {self.config.rec_loss_weight}*Reconstruction"
+            f"Weighted sum : {self.config.n_cuts_weight}*NCuts + {self.config.rec_loss_weight}*Reconstruction"
         )
         ##############
         self.log("-- Data --")
         self.log("Training data :")
         [
-            self.log(f"\n{v}")
+            self.log(f"{v}")
             for d in self.config.train_data_dict
             for k, v in d.items()
         ]
         if self.config.eval_volume_dict is not None:
             self.log("Validation data :")
             [
-                self.log(f"\n{k}: {v}")
+                self.log(f"{k}: {v}")
                 for d in self.config.eval_volume_dict
                 for k, v in d.items()
             ]
@@ -443,9 +441,9 @@ def train(self):
         set_track_meta(False)
         ##############
         # if WANDB_INSTALLED:
-        # wandb.init(
-        # config=WANDB_CONFIG, project="WNet-benchmark", mode=WANDB_MODE
-        # )
+        #     wandb.init(
+        #         config=WANDB_CONFIG, project="WNet-benchmark", mode=WANDB_MODE
+        #     )
 
         set_determinism(
             seed=self.config.deterministic_config.seed
@@ -455,12 +453,8 @@ def train(self):
         normalize_function = utils.remap_image
         device = self.config.device
 
-        self.log(f"Using device: {device}")
-
-        # self.log("Config:") # FIXME log_parameters func instead
-        # [self.log(str(a)) for a in self.config.__dict__.items()]
+        # self.log(f"Using device: {device}")
         self.log_parameters()
-
         self.log("Initializing training...")
         self.log("Getting the data")
 
@@ -473,7 +467,6 @@ def train(self):
         # Training the model #
         ###################################################
         self.log("Initializing the model:")
-
         self.log("- Getting the model")
         # Initialize the model
         model = WNet(
@@ -494,8 +487,8 @@ def train(self):
             )
         )
 
-        if WANDB_INSTALLED:
-            wandb.watch(model, log_freq=100)
+        # if WANDB_INSTALLED:
+        #     wandb.watch(model, log_freq=100)
 
         if self.config.weights_info.custom:
             if self.config.weights_info.use_pretrained:
@@ -619,10 +612,10 @@ def train(self):
                 )
 
                 epoch_rec_loss += reconstruction_loss.item()
-                if WANDB_INSTALLED:
-                    wandb.log(
-                        {"Reconstruction loss": reconstruction_loss.item()}
-                    )
+                # if WANDB_INSTALLED:
+                #     wandb.log(
+                #         {"Reconstruction loss": reconstruction_loss.item()}
+                #     )
 
                 # Backward pass for the reconstruction loss
                 optimizer.zero_grad()
@@ -631,8 +624,8 @@ def train(self):
 
                 loss = alpha * Ncuts + beta * reconstruction_loss
                 epoch_loss += loss.item()
-                if WANDB_INSTALLED:
-                    wandb.log({"Weighted sum of losses": loss.item()})
+                # if WANDB_INSTALLED:
+                #     wandb.log({"Weighted sum of losses": loss.item()})
                 loss.backward(loss)
                 optimizer.step()
 
@@ -818,9 +811,9 @@ def train(self):
                     self.log(f"Saving new best model to {save_path}")
                     torch.save(model.state_dict(), save_path)
 
-                    if WANDB_INSTALLED:
-                        # log validation dice score for each validation round
-                        wandb.log({"val/dice_metric": metric})
+                    # if WANDB_INSTALLED:
+                    # log validation dice score for each validation round
+                    #     wandb.log({"val/dice_metric": metric})
 
                     dec_out_val = (
                         val_decoder_outputs[0].detach().cpu().numpy()
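One of the log tweaks above spells out how the two objectives are combined in the training loop, loss = alpha * Ncuts + beta * reconstruction_loss, with alpha and beta coming from n_cuts_weight and rec_loss_weight. A toy sketch of that weighted sum follows, with placeholder weights and loss values rather than the worker's real criteria.

import torch

# Placeholder weights standing in for config.n_cuts_weight / config.rec_loss_weight.
alpha, beta = 0.5, 0.005

# Placeholder loss values; in the worker these come from the N-Cuts and
# reconstruction criteria evaluated on the current batch.
Ncuts = torch.tensor(1.7, requires_grad=True)
reconstruction_loss = torch.tensor(42.0, requires_grad=True)

loss = alpha * Ncuts + beta * reconstruction_loss
loss.backward()

print(f"Weighted sum : {alpha}*NCuts + {beta}*Reconstruction = {loss.item():.3f}")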

napari_cellseg3d/code_plugins/plugin_model_training.py

Lines changed: 1 addition & 1 deletion
@@ -1402,7 +1402,7 @@ def _show_plot_max(self, plot, y):
             x_max,
             dice_max,
             c="r",
-            label="Max. Dice.",
+            label="Max. Dice",
             zorder=5,
         )
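The plugin-side change is only a label tweak on the marker that _show_plot_max draws at the best Dice score. As a rough standalone illustration of that kind of marker (made-up data, not the plugin's plotting code):

import numpy as np
import matplotlib.pyplot as plt

# Made-up validation Dice values, one per validation round.
dice = np.array([0.62, 0.70, 0.74, 0.78, 0.77, 0.81, 0.80])
rounds = np.arange(1, len(dice) + 1)

fig, ax = plt.subplots()
ax.plot(rounds, dice, label="Validation Dice")

# Red marker on the best round, drawn above the curve (zorder=5).
i_max = int(np.argmax(dice))
ax.scatter(rounds[i_max], dice[i_max], c="r", label="Max. Dice", zorder=5)

ax.set_xlabel("Validation round")
ax.set_ylabel("Dice metric")
ax.legend()
plt.show()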
