Skip to content

Commit 487560d

Browse files
committed
Worker warning + weights error handling
- Added signal to emit warning from worker (not functional for now) - Added clearer warning when weights are not compatible in training
1 parent e26285c commit 487560d

File tree

4 files changed

+35
-3
lines changed

4 files changed

+35
-3
lines changed

napari_cellseg3d/log_utility.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import threading
2+
import warnings
23

34
from qtpy import QtCore
45
from qtpy.QtGui import QTextCursor
@@ -79,3 +80,10 @@ def print_and_log(self, text, printing=True):
7980
)
8081
finally:
8182
self.lock.release()
83+
84+
def warn(self, warning):
85+
self.lock.acquire()
86+
try:
87+
warnings.warn(warning)
88+
finally:
89+
self.lock.release()

napari_cellseg3d/model_workers.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,12 @@ def show_progress(count, block_size, total_size):
149149
class LogSignal(WorkerBaseSignals):
150150
"""Signal to send messages to be logged from another thread.
151151
152-
Separate from Worker instances as indicated `here`_"""
152+
Separate from Worker instances as indicated `here`_""" # TODO link ?
153153

154154
log_signal = Signal(str)
155155
"""qtpy.QtCore.Signal: signal to be sent when some text should be logged"""
156+
warn_signal = Signal(str)
157+
"""qtpy.QtCore.Signal: signal to be sent when some warning should be emitted in main thread"""
156158

157159
# Should not be an instance variable but a class variable, not defined in __init__, see
158160
# https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect
@@ -213,6 +215,7 @@ def __init__(
213215
super().__init__(self.inference)
214216
self._signals = LogSignal() # add custom signals
215217
self.log_signal = self._signals.log_signal
218+
self.warn_signal = self._signals.warn_signal
216219
###########################################
217220
###########################################
218221
self.device = device
@@ -252,6 +255,10 @@ def log(self, text):
252255
"""
253256
self.log_signal.emit(text)
254257

258+
def warn(self, warning):
259+
"""Sends a warning to main thread"""
260+
self.warn_signal.emit(warning)
261+
255262
def log_parameters(self):
256263

257264
self.log("-" * 20)
@@ -647,7 +654,10 @@ def __init__(
647654
super().__init__(self.train)
648655
self._signals = LogSignal()
649656
self.log_signal = self._signals.log_signal
657+
self.warn_signal = self._signals.warn_signal
650658

659+
self._weight_error = False
660+
#############################################
651661
self.device = device
652662
self.model_dict = model_dict
653663
self.weights_path = weights_path
@@ -669,7 +679,7 @@ def __init__(
669679

670680
self.train_files = []
671681
self.val_files = []
672-
682+
#######################################
673683
self.downloader = WeightsDownloader()
674684

675685
def set_download_log(self, widget):
@@ -683,6 +693,10 @@ def log(self, text):
683693
"""
684694
self.log_signal.emit(text)
685695

696+
def warn(self, warning):
697+
"""Sends a warning to main thread"""
698+
self.warn_signal.emit(warning)
699+
686700
def log_parameters(self):
687701

688702
self.log("-" * 20)
@@ -726,6 +740,13 @@ def log_parameters(self):
726740

727741
if self.weights_path is not None:
728742
self.log(f"Using weights from : {self.weights_path}")
743+
if self._weight_error:
744+
self.log(
745+
">>>>>>>>>>>>>>>>>\n"
746+
"WARNING:\nChosen weights were incompatible with the model,\n"
747+
"the model will be trained from random weights\n"
748+
"<<<<<<<<<<<<<<<<<\n"
749+
)
729750

730751
# self.log("\n")
731752
self.log("-" * 20)
@@ -959,7 +980,8 @@ def train(self):
959980
"the model will be trained from random weights"
960981
)
961982
self.log(warn)
962-
warnings.warn(warn)
983+
self.warn(warn)
984+
self._weight_error = True
963985

964986
if self.device.type == "cuda":
965987
self.log("\nUsing GPU :")

napari_cellseg3d/plugin_model_inference.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,7 @@ def start(self):
599599

600600
self.worker.started.connect(self.on_start)
601601
self.worker.log_signal.connect(self.log.print_and_log)
602+
self.worker.warn_signal.connect(self.log.warn)
602603
self.worker.yielded.connect(yield_connect_show_res)
603604
self.worker.errored.connect(
604605
yield_connect_show_res

napari_cellseg3d/plugin_model_training.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,7 @@ def start(self):
857857
[btn.setVisible(False) for btn in self.close_buttons]
858858

859859
self.worker.log_signal.connect(self.log.print_and_log)
860+
self.worker.warn_signal.connect(self.log.warn)
860861

861862
self.worker.started.connect(self.on_start)
862863

0 commit comments

Comments
 (0)