Skip to content

Commit 272449e

Browse files
committed
Moved weights download + enhancements
- Weight download now always happens in worker thread to avoid freezing UI - Download progress bar now shows on the log in the plugin - Models .py files changed to be as simple as possible (so users can more easily add custom models)
1 parent fc4af1e commit 272449e

File tree

11 files changed

+125
-60
lines changed

11 files changed

+125
-60
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ __pycache__/
1010
*.tif
1111
napari_cellseg3d/_tests/res/*.csv
1212
*.pth
13+
*.db
1314

1415
# Distribution / packaging
1516
.Python

napari_cellseg3d/interface.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Optional
22
from typing import Union
33

4+
45
from qtpy.QtCore import Qt
56
from qtpy.QtCore import QUrl
67
from qtpy.QtGui import QDesktopServices

napari_cellseg3d/log_utility.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import threading
22

3+
from qtpy import QtCore
34
from qtpy.QtGui import QTextCursor
45
from qtpy.QtWidgets import QTextEdit
56

@@ -22,6 +23,38 @@ def __init__(self, parent):
2223

2324
# def receive_log(self, text):
2425
# self.print_and_log(text)
26+
def write(self, message):
27+
self.lock.acquire()
28+
try:
29+
if not hasattr(self, "flag"):
30+
self.flag = False
31+
message = message.replace('\r', '').rstrip()
32+
if message:
33+
method = "replace_last_line" if self.flag else "append"
34+
QtCore.QMetaObject.invokeMethod(self,
35+
method,
36+
QtCore.Qt.QueuedConnection,
37+
QtCore.Q_ARG(str, message))
38+
self.flag = True
39+
else:
40+
self.flag = False
41+
42+
finally:
43+
self.lock.release()
44+
45+
@QtCore.Slot(str)
46+
def replace_last_line(self, text):
47+
self.lock.acquire()
48+
try:
49+
cursor = self.textCursor()
50+
cursor.movePosition(QTextCursor.End)
51+
cursor.select(QTextCursor.BlockUnderCursor)
52+
cursor.removeSelectedText()
53+
cursor.insertBlock()
54+
self.setTextCursor(cursor)
55+
self.insertPlainText(text)
56+
finally:
57+
self.lock.release()
2558

2659
def print_and_log(self, text):
2760
"""Utility used to both print to terminal and log text to a QTextEdit

napari_cellseg3d/model_workers.py

Lines changed: 78 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import os
22
import platform
33
from pathlib import Path
4+
import importlib.util
5+
from typing import Optional
46

57
import numpy as np
8+
from tifffile import imwrite
69
import torch
10+
from tqdm import tqdm
711

812
# MONAI
913
from monai.data import CacheDataset
@@ -37,9 +41,10 @@
3741

3842
# Qt
3943
from qtpy.QtCore import Signal
40-
from tifffile import imwrite
44+
4145

4246
from napari_cellseg3d import utils
47+
from napari_cellseg3d import log_utility
4348

4449
# local
4550
from napari_cellseg3d.model_instance_seg import binary_connected
@@ -57,9 +62,61 @@
5762
# https://napari-staging-site.github.io/guides/stable/threading.html
5863

5964
WEIGHTS_DIR = os.path.dirname(os.path.realpath(__file__)) + str(
60-
Path("/models/saved_weights")
65+
Path("/models/pretrained")
6166
)
6267

68+
class WeightsDownloader:
69+
70+
def __init__(self, log_widget: Optional[log_utility.Log]= None):
71+
self.log_widget = log_widget
72+
73+
def download_weights(self,model_name: str):
74+
"""
75+
Downloads a specific pretrained model.
76+
This code is adapted from DeepLabCut with permission from MWMathis.
77+
78+
Args:
79+
model_name (str): name of the model to download
80+
"""
81+
import json
82+
import tarfile
83+
import urllib.request
84+
85+
def show_progress(count, block_size, total_size):
86+
pbar.update(block_size)
87+
88+
cellseg3d_path = os.path.split(
89+
importlib.util.find_spec("napari_cellseg3d").origin
90+
)[0]
91+
pretrained_folder_path = os.path.join(
92+
cellseg3d_path, "models", "pretrained"
93+
)
94+
json_path = os.path.join(
95+
pretrained_folder_path, "pretrained_model_urls.json"
96+
)
97+
with open(json_path) as f:
98+
neturls = json.load(f)
99+
if model_name in neturls.keys():
100+
url = neturls[model_name]
101+
response = urllib.request.urlopen(url)
102+
103+
start_message = f"Downloading the model from the M.W. Mathis Lab server {url}...."
104+
total_size = int(response.getheader("Content-Length"))
105+
if self.log_widget is None:
106+
print(start_message)
107+
pbar = tqdm(unit="B", total=total_size, position=0)
108+
else:
109+
self.log_widget.print_and_log(start_message)
110+
pbar = tqdm(unit="B", total=total_size, position=0, file=self.log_widget)
111+
112+
filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress)
113+
with tarfile.open(filename, mode="r:gz") as tar:
114+
tar.extractall(pretrained_folder_path)
115+
else:
116+
raise ValueError(
117+
f"Unknown model. `modelname` should be one of {', '.join(neturls)}"
118+
)
119+
63120

64121
class LogSignal(WorkerBaseSignals):
65122
"""Signal to send messages to be logged from another thread.
@@ -142,9 +199,12 @@ def __init__(
142199
self.window_infer_size = window_infer_size
143200
self.keep_on_cpu = keep_on_cpu
144201
self.stats_to_csv = stats_csv
145-
146202
"""These attributes are all arguments of :py:func:~inference, please see that for reference"""
147203

204+
self.downloader = WeightsDownloader()
205+
"""Download utility"""
206+
207+
148208
@staticmethod
149209
def create_inference_dict(images_filepaths):
150210
"""Create a dict for MONAI with "image" keys with all image paths in :py:attr:`~self.images_filepaths`
@@ -154,6 +214,9 @@ def create_inference_dict(images_filepaths):
154214
data_dicts = [{"image": image_name} for image_name in images_filepaths]
155215
return data_dicts
156216

217+
def set_download_log(self, widget):
218+
self.downloader.log_widget = widget
219+
157220
def log(self, text):
158221
"""Sends a signal that ``text`` should be logged
159222
@@ -233,8 +296,8 @@ def inference(self):
233296
"""
234297
sys = platform.system()
235298
print(f"OS is {sys}")
236-
if sys == "Darwin": # required for macOS ?
237-
torch.set_num_threads(1)
299+
if sys == "Darwin":
300+
torch.set_num_threads(1) # required for threading on macOS ?
238301
self.log("Number of threads has been set to 1 for macOS")
239302

240303
images_dict = self.create_inference_dict(self.images_filepaths)
@@ -260,7 +323,7 @@ def inference(self):
260323
model = self.model_dict["class"].get_net()
261324
if self.model_dict["name"] == "SegResNet":
262325
model = self.model_dict["class"].get_net()(
263-
input_image_size=[dims, dims, dims], # TODO FIX !
326+
input_image_size=[dims, dims, dims], # TODO FIX ! find a better way & remove model-specific code
264327
out_channels=1,
265328
# dropout_prob=0.3,
266329
)
@@ -304,12 +367,13 @@ def inference(self):
304367
# print(weights)
305368
self.log(
306369
"\nLoading weights..."
307-
) # TODO add try/except for invalid weights
370+
) # TODO add try/except for invalid weights for proper reset
308371

309372
if self.weights_dict["custom"]:
310373
weights = self.weights_dict["path"]
311374
else:
312-
weights = os.path.join(WEIGHTS_DIR, self.weights_dict["path"])
375+
self.downloader.download_weights(self.model_dict["name"])
376+
weights = os.path.join(WEIGHTS_DIR, self.model_dict["class"].get_weights_file())
313377

314378
model.load_state_dict(
315379
torch.load(
@@ -544,8 +608,6 @@ def __init__(
544608
545609
546610
"""
547-
548-
print("init")
549611
super().__init__(self.train)
550612
self._signals = LogSignal()
551613
self.log_signal = self._signals.log_signal
@@ -571,7 +633,11 @@ def __init__(
571633

572634
self.train_files = []
573635
self.val_files = []
574-
print("end init")
636+
637+
self.downloader = WeightsDownloader()
638+
639+
def set_download_log(self, widget):
640+
self.downloader.log_widget = widget
575641

576642
def log(self, text):
577643
"""Sends a signal that ``text`` should be logged
@@ -838,6 +904,7 @@ def train(self):
838904
if self.weights_path is not None:
839905
if self.weights_path == "use_pretrained":
840906
weights_file = model_class.get_weights_file()
907+
self.downloader.download_weights(model_name)
841908
weights = os.path.join(WEIGHTS_DIR, weights_file)
842909
self.weights_path = weights
843910
else:

napari_cellseg3d/models/TRAILMAP_MS.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88

99
def get_weights_file():
1010
# model additionally trained on Mathis/Wyss mesoSPIM data
11-
target_dir = utils.download_model("TRAILMAP_MS")
12-
return os.path.join(target_dir, "TRAILMAP_MS_best_metric_epoch_26.pth")
11+
return "TRAILMAP_MS_best_metric_epoch_26.pth"
1312

1413

1514
def get_net():

napari_cellseg3d/models/model_SegResNet.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ def get_net():
1010

1111

1212
def get_weights_file():
13-
target_dir = utils.download_model("SegResNet")
14-
return os.path.join(target_dir, "SegResNet.pth")
13+
return "SegResNet.pth"
1514

1615

1716
def get_output(model, input):

napari_cellseg3d/models/model_TRAILMAP.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66

77
def get_weights_file():
88
# original model from Liqun Luo lab, transfered to pytorch
9-
target_dir = utils.download_model("TRAILMAP")
10-
return os.path.join(target_dir, "TRAILMAP_PyTorch.pth")
9+
return "TRAILMAP_PyTorch.pth"
1110

1211

1312
def get_net():

napari_cellseg3d/models/model_VNet.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ def get_net():
1111

1212

1313
def get_weights_file():
14-
target_dir = utils.download_model("VNet")
15-
return os.path.join(target_dir, "VNet_40e.pth")
14+
return "VNet_40e.pth"
1615

1716

1817
def get_output(model, input):

napari_cellseg3d/plugin_model_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,6 @@ def start(self):
534534
else:
535535
weights_dict = {
536536
"custom": False,
537-
"path": self.get_model(model_key).get_weights_file(),
538537
}
539538

540539
if self.anisotropy_wdgt.is_enabled():
@@ -591,6 +590,7 @@ def start(self):
591590
keep_on_cpu=self.keep_on_cpu,
592591
stats_csv=self.stats_to_csv,
593592
)
593+
self.worker.set_download_log(self.log)
594594

595595
yield_connect_show_res = lambda data: self.on_yield(
596596
data,

napari_cellseg3d/plugin_model_training.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,7 @@ def start(self):
851851
do_augmentation=self.augment_choice.isChecked(),
852852
deterministic=seed_dict,
853853
)
854+
self.worker.set_download_log(self.log)
854855

855856
[btn.setVisible(False) for btn in self.close_buttons]
856857

@@ -988,6 +989,11 @@ def on_yield(data, widget):
988989
def make_csv(self):
989990

990991
size_column = range(1, self.max_epochs + 1)
992+
993+
if len(self.loss_values) == 0 or self.loss_values is None:
994+
warnings.warn("No loss values to add to csv !")
995+
return
996+
991997
self.df = pd.DataFrame(
992998
{
993999
"epoch": size_column,

0 commit comments

Comments
 (0)