11import os
22import platform
33from pathlib import Path
4+ import importlib .util
5+ from typing import Optional
46
57import numpy as np
8+ from tifffile import imwrite
69import torch
10+ from tqdm import tqdm
711
812# MONAI
913from monai .data import CacheDataset
3741
3842# Qt
3943from qtpy .QtCore import Signal
40- from tifffile import imwrite
44+
4145
4246from napari_cellseg3d import utils
47+ from napari_cellseg3d import log_utility
4348
4449# local
4550from napari_cellseg3d .model_instance_seg import binary_connected
5762# https://napari-staging-site.github.io/guides/stable/threading.html
5863
5964WEIGHTS_DIR = os .path .dirname (os .path .realpath (__file__ )) + str (
60- Path ("/models/saved_weights " )
65+ Path ("/models/pretrained " )
6166)
6267
68+ class WeightsDownloader :
69+
70+ def __init__ (self , log_widget : Optional [log_utility .Log ]= None ):
71+ self .log_widget = log_widget
72+
73+ def download_weights (self ,model_name : str ):
74+ """
75+ Downloads a specific pretrained model.
76+ This code is adapted from DeepLabCut with permission from MWMathis.
77+
78+ Args:
79+ model_name (str): name of the model to download
80+ """
81+ import json
82+ import tarfile
83+ import urllib .request
84+
85+ def show_progress (count , block_size , total_size ):
86+ pbar .update (block_size )
87+
88+ cellseg3d_path = os .path .split (
89+ importlib .util .find_spec ("napari_cellseg3d" ).origin
90+ )[0 ]
91+ pretrained_folder_path = os .path .join (
92+ cellseg3d_path , "models" , "pretrained"
93+ )
94+ json_path = os .path .join (
95+ pretrained_folder_path , "pretrained_model_urls.json"
96+ )
97+ with open (json_path ) as f :
98+ neturls = json .load (f )
99+ if model_name in neturls .keys ():
100+ url = neturls [model_name ]
101+ response = urllib .request .urlopen (url )
102+
103+ start_message = f"Downloading the model from the M.W. Mathis Lab server { url } ...."
104+ total_size = int (response .getheader ("Content-Length" ))
105+ if self .log_widget is None :
106+ print (start_message )
107+ pbar = tqdm (unit = "B" , total = total_size , position = 0 )
108+ else :
109+ self .log_widget .print_and_log (start_message )
110+ pbar = tqdm (unit = "B" , total = total_size , position = 0 , file = self .log_widget )
111+
112+ filename , _ = urllib .request .urlretrieve (url , reporthook = show_progress )
113+ with tarfile .open (filename , mode = "r:gz" ) as tar :
114+ tar .extractall (pretrained_folder_path )
115+ else :
116+ raise ValueError (
117+ f"Unknown model. `modelname` should be one of { ', ' .join (neturls )} "
118+ )
119+
63120
64121class LogSignal (WorkerBaseSignals ):
65122 """Signal to send messages to be logged from another thread.
@@ -142,9 +199,12 @@ def __init__(
142199 self .window_infer_size = window_infer_size
143200 self .keep_on_cpu = keep_on_cpu
144201 self .stats_to_csv = stats_csv
145-
146202 """These attributes are all arguments of :py:func:~inference, please see that for reference"""
147203
204+ self .downloader = WeightsDownloader ()
205+ """Download utility"""
206+
207+
148208 @staticmethod
149209 def create_inference_dict (images_filepaths ):
150210 """Create a dict for MONAI with "image" keys with all image paths in :py:attr:`~self.images_filepaths`
@@ -154,6 +214,9 @@ def create_inference_dict(images_filepaths):
154214 data_dicts = [{"image" : image_name } for image_name in images_filepaths ]
155215 return data_dicts
156216
217+ def set_download_log (self , widget ):
218+ self .downloader .log_widget = widget
219+
157220 def log (self , text ):
158221 """Sends a signal that ``text`` should be logged
159222
@@ -233,8 +296,8 @@ def inference(self):
233296 """
234297 sys = platform .system ()
235298 print (f"OS is { sys } " )
236- if sys == "Darwin" : # required for macOS ?
237- torch .set_num_threads (1 )
299+ if sys == "Darwin" :
300+ torch .set_num_threads (1 ) # required for threading on macOS ?
238301 self .log ("Number of threads has been set to 1 for macOS" )
239302
240303 images_dict = self .create_inference_dict (self .images_filepaths )
@@ -260,7 +323,7 @@ def inference(self):
260323 model = self .model_dict ["class" ].get_net ()
261324 if self .model_dict ["name" ] == "SegResNet" :
262325 model = self .model_dict ["class" ].get_net ()(
263- input_image_size = [dims , dims , dims ], # TODO FIX !
326+ input_image_size = [dims , dims , dims ], # TODO FIX ! find a better way & remove model-specific code
264327 out_channels = 1 ,
265328 # dropout_prob=0.3,
266329 )
@@ -304,12 +367,13 @@ def inference(self):
304367 # print(weights)
305368 self .log (
306369 "\n Loading weights..."
307- ) # TODO add try/except for invalid weights
370+ ) # TODO add try/except for invalid weights for proper reset
308371
309372 if self .weights_dict ["custom" ]:
310373 weights = self .weights_dict ["path" ]
311374 else :
312- weights = os .path .join (WEIGHTS_DIR , self .weights_dict ["path" ])
375+ self .downloader .download_weights (self .model_dict ["name" ])
376+ weights = os .path .join (WEIGHTS_DIR , self .model_dict ["class" ].get_weights_file ())
313377
314378 model .load_state_dict (
315379 torch .load (
@@ -544,8 +608,6 @@ def __init__(
544608
545609
546610 """
547-
548- print ("init" )
549611 super ().__init__ (self .train )
550612 self ._signals = LogSignal ()
551613 self .log_signal = self ._signals .log_signal
@@ -571,7 +633,11 @@ def __init__(
571633
572634 self .train_files = []
573635 self .val_files = []
574- print ("end init" )
636+
637+ self .downloader = WeightsDownloader ()
638+
639+ def set_download_log (self , widget ):
640+ self .downloader .log_widget = widget
575641
576642 def log (self , text ):
577643 """Sends a signal that ``text`` should be logged
@@ -838,6 +904,7 @@ def train(self):
838904 if self .weights_path is not None :
839905 if self .weights_path == "use_pretrained" :
840906 weights_file = model_class .get_weights_file ()
907+ self .downloader .download_weights (model_name )
841908 weights = os .path .join (WEIGHTS_DIR , weights_file )
842909 self .weights_path = weights
843910 else :
0 commit comments