Skip to content

Commit 8c6fc5b

Browse files
committed
tried BCE loss + fixed saving errors in training + fixed incorrect plot clearing + tested Trailmap more
1 parent 54df4b5, commit 8c6fc5b

File tree

2 files changed

+20
-13
lines changed

2 files changed

+20
-13
lines changed

src/napari_cellseg3d/model_workers.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -459,6 +459,7 @@ def log_parameters(self):
459459
self.log(f"Loss function is : {str(self.loss_function)}")
460460
self.log(f"Validation is performed every {self.val_interval} epochs")
461461
self.log(f"Batch size is {self.batch_size}")
462+
self.log(f"Learning rate is {self.learning_rate}")
462463

463464
if self.sampling:
464465
self.log(
@@ -661,14 +662,15 @@ def train(self):
661662
)
662663
print("\nDone")
663664

665+
print("Optimizer")
664666
optimizer = torch.optim.Adam(model.parameters(), self.learning_rate)
665667
dice_metric = DiceMetric(include_background=True, reduction="mean")
666668

667669
best_metric = -1
668670
best_metric_epoch = -1
669671

670672
# time = utils.get_date_time()
671-
673+
print("Weights")
672674
if self.weights_path is not None:
673675
if self.weights_path == "use_pretrained":
674676
weights_file = model_class.get_weights_file()

src/napari_cellseg3d/plugin_model_training.py

Lines changed: 17 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import os
22
import warnings
33
import torch
4+
from torch import nn
45
from pathlib import Path
56

67
import matplotlib.pyplot as plt
@@ -174,10 +175,12 @@ def __init__(
174175
"""Data dictionary containing file paths"""
175176
self.stop_requested = False
176177
"""Whether the worker should stop or not"""
178+
self.start_time = ""
177179

178180
self.loss_dict = {
179181
"Dice loss": DiceLoss(sigmoid=True),
180182
"Focal loss": FocalLoss(),
183+
# "BCELoss":nn.BCELoss(),
181184
"Dice-Focal loss": DiceFocalLoss(sigmoid=True, lambda_dice=0.5),
182185
"Generalized Dice loss": GeneralizedDiceLoss(sigmoid=True),
183186
"DiceCELoss": DiceCELoss(sigmoid=True, lambda_ce=0.5),
@@ -226,6 +229,7 @@ def __init__(
226229
self.lbl_val_interv_choice = QLabel("Validation interval : ", self)
227230

228231
self.learning_rate_dict = {
232+
"1e-2": 1e-2,
229233
"1e-3": 1e-3,
230234
"1e-4": 1e-4,
231235
"1e-5": 1e-5,
@@ -238,6 +242,7 @@ def __init__(
238242
) = ui.make_combobox(
239243
self.learning_rate_dict.keys(), label="Learning rate"
240244
)
245+
self.learning_rate_choice.setCurrentIndex(1)
241246

242247
self.augment_choice = ui.make_checkbox("Augment data")
243248

@@ -676,7 +681,7 @@ def start(self):
676681
Returns: Returns empty immediately if the file paths are not set correctly.
677682
678683
"""
679-
684+
self.start_time = utils.get_time_filepath()
680685
if self.stop_requested:
681686
self.log.print_and_log("Worker is already stopping !")
682687
return
@@ -720,8 +725,11 @@ def start(self):
720725

721726
self.results_path = (
722727
self.results_path
723-
+ f"/{model_dict['name']}_results_{utils.get_date_time()}"
728+
+ f"/{model_dict['name']}_results_{self.start_time}"
724729
)
730+
os.makedirs(
731+
self.results_path, exist_ok=False
732+
) # avoid overwrite where possible
725733

726734
if self.use_transfer_choice.isChecked():
727735
if self.custom_weights_choice.isChecked():
@@ -731,10 +739,6 @@ def start(self):
731739
else:
732740
weights_path = None
733741

734-
os.makedirs(
735-
self.results_path, exist_ok=False
736-
) # avoid overwrite where possible
737-
738742
self.log.print_and_log(
739743
f"Notice : Saving results to : {self.results_path}"
740744
)
@@ -782,13 +786,11 @@ def start(self):
782786

783787
def on_start(self):
784788
"""Catches started signal from worker"""
785-
if self.plot_dock is not None:
786-
self._viewer.window.remove_dock_widget(self.plot_dock)
787-
self.plot_dock = None
788789

790+
self.remove_docked_widgets()
789791
self.display_status_report()
790792

791-
self.log.print_and_log(f"Worker started at {utils.get_time()}")
793+
self.log.print_and_log(f"Worker started at {self.start_time}")
792794
self.log.print_and_log("\nWorker is running...")
793795

794796
def on_finish(self):
@@ -801,7 +803,7 @@ def on_finish(self):
801803
self.canvas.figure.savefig(
802804
(
803805
self.results_path
804-
+ f"/final_metric_plots_{utils.get_date_time()}.png"
806+
+ f"/final_metric_plots_{utils.get_time_filepath()}.png"
805807
),
806808
format="png",
807809
)
@@ -817,11 +819,14 @@ def on_finish(self):
817819

818820
self.worker = None
819821
self.empty_cuda_cache()
822+
823+
self.results_path = ""
820824
# self.clean_cache() # trying to fix memory leak
821825

822826
def on_error(self):
823827
"""Catches errored signal from worker"""
824828
self.log.print_and_log(f"WORKER ERRORED at {utils.get_time()}")
829+
self.worker=None
825830
self.empty_cuda_cache()
826831
# self.clean_cache()
827832

@@ -840,7 +845,7 @@ def on_yield(data, widget):
840845
data["weights"],
841846
os.path.join(
842847
widget.results_path,
843-
f"latest_weights_aborted_training_{utils.get_date_time()}.pth",
848+
f"latest_weights_aborted_training_{utils.get_time()}.pth",
844849
),
845850
)
846851
widget.stop_requested = False

0 commit comments

Comments
 (0)