Skip to content

Commit 1709528

Browse files
authored
Merge pull request #19 from AdaptiveMotorControlLab/cy/log_weights_download
Moved weights download + enhancements
2 parents eb33f31 + 8d143d2 commit 1709528

File tree

12 files changed

+259
-204
lines changed

12 files changed

+259
-204
lines changed

napari_cellseg3d/interface.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Union
33
from typing import List
44

5+
56
from qtpy.QtCore import Qt
67
from qtpy.QtCore import QUrl
78
from qtpy.QtGui import QDesktopServices

napari_cellseg3d/log_utility.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import threading
22

3+
from qtpy import QtCore
34
from qtpy.QtGui import QTextCursor
45
from qtpy.QtWidgets import QTextEdit
56

@@ -22,19 +23,53 @@ def __init__(self, parent):
2223

2324
# def receive_log(self, text):
2425
# self.print_and_log(text)
26+
def write(self, message):
27+
self.lock.acquire()
28+
try:
29+
if not hasattr(self, "flag"):
30+
self.flag = False
31+
message = message.replace('\r', '').rstrip()
32+
if message:
33+
method = "replace_last_line" if self.flag else "append"
34+
QtCore.QMetaObject.invokeMethod(self,
35+
method,
36+
QtCore.Qt.QueuedConnection,
37+
QtCore.Q_ARG(str, message))
38+
self.flag = True
39+
else:
40+
self.flag = False
41+
42+
finally:
43+
self.lock.release()
44+
45+
@QtCore.Slot(str)
46+
def replace_last_line(self, text):
47+
self.lock.acquire()
48+
try:
49+
cursor = self.textCursor()
50+
cursor.movePosition(QTextCursor.End)
51+
cursor.select(QTextCursor.BlockUnderCursor)
52+
cursor.removeSelectedText()
53+
cursor.insertBlock()
54+
self.setTextCursor(cursor)
55+
self.insertPlainText(text)
56+
finally:
57+
self.lock.release()
2558

26-
def print_and_log(self, text):
59+
def print_and_log(self, text, printing=True):
2760
"""Utility used to both print to terminal and log text to a QTextEdit
2861
item in a thread-safe manner. Use only for important user info.
2962
3063
Args:
3164
text (str): Text to be printed and logged
65+
printing (bool): Whether to print the message as well or not using print(). Defaults to True.
3266
3367
"""
3468
self.lock.acquire()
3569
try:
36-
print(text)
37-
# causes issue if you clik on terminal (tied to CMD QuickEdit mode)
70+
if printing:
71+
print(text)
72+
# causes issue if you clik on terminal (tied to CMD QuickEdit mode on Windows)
3873
self.moveCursor(QTextCursor.End)
3974
self.insertPlainText(f"\n{text}")
4075
self.verticalScrollBar().setValue(

napari_cellseg3d/model_framework.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from napari_cellseg3d.models import model_SegResNet as SegResNet
1616
from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP
1717
from napari_cellseg3d.models import model_VNet as VNet
18-
from napari_cellseg3d.models import TRAILMAP_MS as TMAP
18+
from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS
1919
from napari_cellseg3d.plugin_base import BasePluginFolder
2020

2121
warnings.formatwarning = utils.format_Warning
@@ -62,8 +62,8 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
6262
self.models_dict = {
6363
"VNet": VNet,
6464
"SegResNet": SegResNet,
65-
"TRAILMAP pre-trained": TRAILMAP,
66-
"TRAILMAP_MS": TMAP,
65+
"TRAILMAP": TRAILMAP,
66+
"TRAILMAP_MS": TRAILMAP_MS,
6767
}
6868
"""dict: dictionary of available models, with string for widget display as key
6969

napari_cellseg3d/model_workers.py

Lines changed: 79 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import os
22
import platform
33
from pathlib import Path
4+
import importlib.util
5+
from typing import Optional
46

57
import numpy as np
8+
from tifffile import imwrite
69
import torch
10+
from tqdm import tqdm
711

812
# MONAI
913
from monai.data import CacheDataset
@@ -37,9 +41,10 @@
3741

3842
# Qt
3943
from qtpy.QtCore import Signal
40-
from tifffile import imwrite
44+
4145

4246
from napari_cellseg3d import utils
47+
from napari_cellseg3d import log_utility
4348

4449
# local
4550
from napari_cellseg3d.model_instance_seg import binary_connected
@@ -57,9 +62,61 @@
5762
# https://napari-staging-site.github.io/guides/stable/threading.html
5863

5964
WEIGHTS_DIR = os.path.dirname(os.path.realpath(__file__)) + str(
60-
Path("/models/saved_weights")
65+
Path("/models/pretrained")
6166
)
6267

68+
class WeightsDownloader:
69+
70+
def __init__(self, log_widget: Optional[log_utility.Log]= None):
71+
self.log_widget = log_widget
72+
73+
def download_weights(self,model_name: str):
74+
"""
75+
Downloads a specific pretrained model.
76+
This code is adapted from DeepLabCut with permission from MWMathis.
77+
78+
Args:
79+
model_name (str): name of the model to download
80+
"""
81+
import json
82+
import tarfile
83+
import urllib.request
84+
85+
def show_progress(count, block_size, total_size):
86+
pbar.update(block_size)
87+
88+
cellseg3d_path = os.path.split(
89+
importlib.util.find_spec("napari_cellseg3d").origin
90+
)[0]
91+
pretrained_folder_path = os.path.join(
92+
cellseg3d_path, "models", "pretrained"
93+
)
94+
json_path = os.path.join(
95+
pretrained_folder_path, "pretrained_model_urls.json"
96+
)
97+
with open(json_path) as f:
98+
neturls = json.load(f)
99+
if model_name in neturls.keys():
100+
url = neturls[model_name]
101+
response = urllib.request.urlopen(url)
102+
103+
start_message = f"Downloading the model from the M.W. Mathis Lab server {url}...."
104+
total_size = int(response.getheader("Content-Length"))
105+
if self.log_widget is None:
106+
print(start_message)
107+
pbar = tqdm(unit="B", total=total_size, position=0)
108+
else:
109+
self.log_widget.print_and_log(start_message)
110+
pbar = tqdm(unit="B", total=total_size, position=0, file=self.log_widget)
111+
112+
filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress)
113+
with tarfile.open(filename, mode="r:gz") as tar:
114+
tar.extractall(pretrained_folder_path)
115+
else:
116+
raise ValueError(
117+
f"Unknown model: {model_name}. Should be one of {', '.join(neturls)}"
118+
)
119+
63120

64121
class LogSignal(WorkerBaseSignals):
65122
"""Signal to send messages to be logged from another thread.
@@ -142,9 +199,12 @@ def __init__(
142199
self.window_infer_size = window_infer_size
143200
self.keep_on_cpu = keep_on_cpu
144201
self.stats_to_csv = stats_csv
145-
146202
"""These attributes are all arguments of :py:func:~inference, please see that for reference"""
147203

204+
self.downloader = WeightsDownloader()
205+
"""Download utility"""
206+
207+
148208
@staticmethod
149209
def create_inference_dict(images_filepaths):
150210
"""Create a dict for MONAI with "image" keys with all image paths in :py:attr:`~self.images_filepaths`
@@ -154,6 +214,9 @@ def create_inference_dict(images_filepaths):
154214
data_dicts = [{"image": image_name} for image_name in images_filepaths]
155215
return data_dicts
156216

217+
def set_download_log(self, widget):
218+
self.downloader.log_widget = widget
219+
157220
def log(self, text):
158221
"""Sends a signal that ``text`` should be logged
159222
@@ -233,8 +296,8 @@ def inference(self):
233296
"""
234297
sys = platform.system()
235298
print(f"OS is {sys}")
236-
if sys == "Darwin": # required for macOS ?
237-
torch.set_num_threads(1)
299+
if sys == "Darwin":
300+
torch.set_num_threads(1) # required for threading on macOS ?
238301
self.log("Number of threads has been set to 1 for macOS")
239302

240303
images_dict = self.create_inference_dict(self.images_filepaths)
@@ -260,7 +323,7 @@ def inference(self):
260323
model = self.model_dict["class"].get_net()
261324
if self.model_dict["name"] == "SegResNet":
262325
model = self.model_dict["class"].get_net()(
263-
input_image_size=[dims, dims, dims], # TODO FIX !
326+
input_image_size=[dims, dims, dims], # TODO FIX ! find a better way & remove model-specific code
264327
out_channels=1,
265328
# dropout_prob=0.3,
266329
)
@@ -304,12 +367,13 @@ def inference(self):
304367
# print(weights)
305368
self.log(
306369
"\nLoading weights..."
307-
) # TODO add try/except for invalid weights
370+
) # TODO add try/except for invalid weights for proper reset
308371

309372
if self.weights_dict["custom"]:
310373
weights = self.weights_dict["path"]
311374
else:
312-
weights = os.path.join(WEIGHTS_DIR, self.weights_dict["path"])
375+
self.downloader.download_weights(self.model_dict["name"])
376+
weights = os.path.join(WEIGHTS_DIR, self.model_dict["class"].get_weights_file())
313377

314378
model.load_state_dict(
315379
torch.load(
@@ -544,8 +608,6 @@ def __init__(
544608
545609
546610
"""
547-
548-
print("init")
549611
super().__init__(self.train)
550612
self._signals = LogSignal()
551613
self.log_signal = self._signals.log_signal
@@ -571,7 +633,11 @@ def __init__(
571633

572634
self.train_files = []
573635
self.val_files = []
574-
print("end init")
636+
637+
self.downloader = WeightsDownloader()
638+
639+
def set_download_log(self, widget):
640+
self.downloader.log_widget = widget
575641

576642
def log(self, text):
577643
"""Sends a signal that ``text`` should be logged
@@ -629,7 +695,7 @@ def log_parameters(self):
629695
self.log("-" * 20)
630696

631697
def train(self):
632-
"""Trains the Pytorch model for the given number of epochs, with the selected model and data,
698+
"""Trains the PyTorch model for the given number of epochs, with the selected model and data,
633699
using the chosen batch size, validation interval, loss function, and number of samples.
634700
Will perform validation once every :py:obj:`val_interval` and save results if the mean dice is better
635701
@@ -838,6 +904,7 @@ def train(self):
838904
if self.weights_path is not None:
839905
if self.weights_path == "use_pretrained":
840906
weights_file = model_class.get_weights_file()
907+
self.downloader.download_weights(model_name)
841908
weights = os.path.join(WEIGHTS_DIR, weights_file)
842909
self.weights_path = weights
843910
else:

0 commit comments

Comments
 (0)