11import os
22import platform
33from pathlib import Path
4+ import importlib .util
5+ from typing import Optional
46
57import numpy as np
8+ from tifffile import imwrite
69import torch
10+ from tqdm import tqdm
711
812# MONAI
913from monai .data import CacheDataset
3741
3842# Qt
3943from qtpy .QtCore import Signal
40- from tifffile import imwrite
44+
4145
4246from napari_cellseg3d import utils
47+ from napari_cellseg3d import log_utility
4348
4449# local
4550from napari_cellseg3d .model_instance_seg import binary_connected
5762# https://napari-staging-site.github.io/guides/stable/threading.html
5863
5964WEIGHTS_DIR = os .path .dirname (os .path .realpath (__file__ )) + str (
60- Path ("/models/saved_weights " )
65+ Path ("/models/pretrained " )
6166)
6267
68+ class WeightsDownloader :
69+
70+ def __init__ (self , log_widget : Optional [log_utility .Log ]= None ):
71+ self .log_widget = log_widget
72+
73+ def download_weights (self ,model_name : str ):
74+ """
75+ Downloads a specific pretrained model.
76+ This code is adapted from DeepLabCut with permission from MWMathis.
77+
78+ Args:
79+ model_name (str): name of the model to download
80+ """
81+ import json
82+ import tarfile
83+ import urllib .request
84+
85+ def show_progress (count , block_size , total_size ):
86+ pbar .update (block_size )
87+
88+ cellseg3d_path = os .path .split (
89+ importlib .util .find_spec ("napari_cellseg3d" ).origin
90+ )[0 ]
91+ pretrained_folder_path = os .path .join (
92+ cellseg3d_path , "models" , "pretrained"
93+ )
94+ json_path = os .path .join (
95+ pretrained_folder_path , "pretrained_model_urls.json"
96+ )
97+ with open (json_path ) as f :
98+ neturls = json .load (f )
99+ if model_name in neturls .keys ():
100+ url = neturls [model_name ]
101+ response = urllib .request .urlopen (url )
102+
103+ start_message = f"Downloading the model from the M.W. Mathis Lab server { url } ...."
104+ total_size = int (response .getheader ("Content-Length" ))
105+ if self .log_widget is None :
106+ print (start_message )
107+ pbar = tqdm (unit = "B" , total = total_size , position = 0 )
108+ else :
109+ self .log_widget .print_and_log (start_message )
110+ pbar = tqdm (unit = "B" , total = total_size , position = 0 , file = self .log_widget )
111+
112+ filename , _ = urllib .request .urlretrieve (url , reporthook = show_progress )
113+ with tarfile .open (filename , mode = "r:gz" ) as tar :
114+ tar .extractall (pretrained_folder_path )
115+ else :
116+ raise ValueError (
117+ f"Unknown model: { model_name } . Should be one of { ', ' .join (neturls )} "
118+ )
119+
63120
64121class LogSignal (WorkerBaseSignals ):
65122 """Signal to send messages to be logged from another thread.
@@ -142,9 +199,12 @@ def __init__(
142199 self .window_infer_size = window_infer_size
143200 self .keep_on_cpu = keep_on_cpu
144201 self .stats_to_csv = stats_csv
145-
146202 """These attributes are all arguments of :py:func:~inference, please see that for reference"""
147203
204+ self .downloader = WeightsDownloader ()
205+ """Download utility"""
206+
207+
148208 @staticmethod
149209 def create_inference_dict (images_filepaths ):
150210 """Create a dict for MONAI with "image" keys with all image paths in :py:attr:`~self.images_filepaths`
@@ -154,6 +214,9 @@ def create_inference_dict(images_filepaths):
154214 data_dicts = [{"image" : image_name } for image_name in images_filepaths ]
155215 return data_dicts
156216
217+ def set_download_log (self , widget ):
218+ self .downloader .log_widget = widget
219+
157220 def log (self , text ):
158221 """Sends a signal that ``text`` should be logged
159222
@@ -233,8 +296,8 @@ def inference(self):
233296 """
234297 sys = platform .system ()
235298 print (f"OS is { sys } " )
236- if sys == "Darwin" : # required for macOS ?
237- torch .set_num_threads (1 )
299+ if sys == "Darwin" :
300+ torch .set_num_threads (1 ) # required for threading on macOS ?
238301 self .log ("Number of threads has been set to 1 for macOS" )
239302
240303 images_dict = self .create_inference_dict (self .images_filepaths )
@@ -260,7 +323,7 @@ def inference(self):
260323 model = self .model_dict ["class" ].get_net ()
261324 if self .model_dict ["name" ] == "SegResNet" :
262325 model = self .model_dict ["class" ].get_net ()(
263- input_image_size = [dims , dims , dims ], # TODO FIX !
326+ input_image_size = [dims , dims , dims ], # TODO FIX ! find a better way & remove model-specific code
264327 out_channels = 1 ,
265328 # dropout_prob=0.3,
266329 )
@@ -304,12 +367,13 @@ def inference(self):
304367 # print(weights)
305368 self .log (
306369 "\n Loading weights..."
307- ) # TODO add try/except for invalid weights
370+ ) # TODO add try/except for invalid weights for proper reset
308371
309372 if self .weights_dict ["custom" ]:
310373 weights = self .weights_dict ["path" ]
311374 else :
312- weights = os .path .join (WEIGHTS_DIR , self .weights_dict ["path" ])
375+ self .downloader .download_weights (self .model_dict ["name" ])
376+ weights = os .path .join (WEIGHTS_DIR , self .model_dict ["class" ].get_weights_file ())
313377
314378 model .load_state_dict (
315379 torch .load (
@@ -544,8 +608,6 @@ def __init__(
544608
545609
546610 """
547-
548- print ("init" )
549611 super ().__init__ (self .train )
550612 self ._signals = LogSignal ()
551613 self .log_signal = self ._signals .log_signal
@@ -571,7 +633,11 @@ def __init__(
571633
572634 self .train_files = []
573635 self .val_files = []
574- print ("end init" )
636+
637+ self .downloader = WeightsDownloader ()
638+
639+ def set_download_log (self , widget ):
640+ self .downloader .log_widget = widget
575641
576642 def log (self , text ):
577643 """Sends a signal that ``text`` should be logged
@@ -629,7 +695,7 @@ def log_parameters(self):
629695 self .log ("-" * 20 )
630696
631697 def train (self ):
632- """Trains the Pytorch model for the given number of epochs, with the selected model and data,
698+ """Trains the PyTorch model for the given number of epochs, with the selected model and data,
633699 using the chosen batch size, validation interval, loss function, and number of samples.
634700 Will perform validation once every :py:obj:`val_interval` and save results if the mean dice is better
635701
@@ -838,6 +904,7 @@ def train(self):
838904 if self .weights_path is not None :
839905 if self .weights_path == "use_pretrained" :
840906 weights_file = model_class .get_weights_file ()
907+ self .downloader .download_weights (model_name )
841908 weights = os .path .join (WEIGHTS_DIR , weights_file )
842909 self .weights_path = weights
843910 else :
0 commit comments