Skip to content

Commit 599fb73

Browse files
committed
WIP autosave weights if worker has been canceled
1 parent 93f5ea8 commit 599fb73

File tree

3 files changed

+19
-2
lines changed

3 files changed

+19
-2
lines changed

src/napari_cellseg3d/model_workers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -768,6 +768,7 @@ def train(self):
768768
"epoch": epoch,
769769
"losses": epoch_loss_values,
770770
"val_metrics": val_metric_values,
771+
"weights": model.state_dict(),
771772
}
772773
yield train_report
773774

@@ -794,8 +795,8 @@ def train(self):
794795
f"at epoch: {best_metric_epoch}"
795796
)
796797
model.to("cpu")
797-
optimizer = None
798-
del optimizer
798+
# optimizer = None
799+
# del optimizer
799800
# del device
800801
# del model_id
801802
# del model_name

src/napari_cellseg3d/plugin_model_inference.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import warnings
3+
import numpy as np
34

45
import napari
56
# Qt
@@ -624,6 +625,9 @@ def on_yield(data, widget):
624625
)
625626

626627
if data["instance_labels"] is not None:
628+
629+
widget.log.print_and_log(f"\nNUMBER OF CELLS : {np.amax(data['instance_labels'])}\n")
630+
627631
name = f"instance_labels_{image_id}"
628632
instance_layer = viewer.add_labels(
629633
data["instance_labels"], name=name

src/napari_cellseg3d/plugin_model_training.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import warnings
3+
import torch
34
from pathlib import Path
45

56
import matplotlib.pyplot as plt
@@ -168,6 +169,8 @@ def __init__(
168169
"""Training worker for multithreading, should be a TrainingWorker instance from :doc:model_workers.py"""
169170
self.data = None
170171
"""Data dictionary containing file paths"""
172+
self.stop_requested = False
173+
"""Whether the worker should stop or not"""
171174

172175
self.loss_dict = {
173176
"Dice loss": DiceLoss(sigmoid=True),
@@ -645,6 +648,10 @@ def start(self):
645648
646649
"""
647650

651+
if self.stop_requested:
652+
self.log.print_and_log("Worker is already stopping !")
653+
return
654+
648655
if not self.check_ready(): # issues a warning if not ready
649656
err = "Aborting, please set all required paths"
650657
self.log.print_and_log(err)
@@ -734,6 +741,7 @@ def start(self):
734741
self.log.print_and_log(
735742
f"Stop requested at {utils.get_time()}. \nWaiting for next validation step..."
736743
)
744+
self.stop_requested = True
737745
self.btn_start.setText("Stopping... Please wait for next saving")
738746
self.worker.quit()
739747
else:
@@ -795,6 +803,10 @@ def on_yield(data, widget):
795803
)
796804
widget.update_loss_plot(data["losses"], data["val_metrics"])
797805

806+
if widget.stop_requested:
807+
torch.save(data["weights"], os.path.join(widget.results_path, f"latest_weights_aborted_training_{utils.get_date_time()}.pth"))
808+
widget.stop_requested = False
809+
798810
# def clean_cache(self):
799811
# """Attempts to clear memory after training"""
800812
# # del self.worker

0 commit comments

Comments
 (0)