33from pathlib import Path
44import importlib .util
55from typing import Optional
6+ import warnings
67
78import numpy as np
89from tifffile import imwrite
6566 Path ("/models/pretrained" )
6667)
6768
69+
6870class WeightsDownloader :
71+ """A utility class the downloads the weights of a model when needed."""
72+
73+ def __init__ (self , log_widget : Optional [log_utility .Log ] = None ):
74+ """
75+ Creates a WeightsDownloader, optionally with a log widget to display the progress.
6976
70- def __init__ (self , log_widget : Optional [log_utility .Log ]= None ):
77+ Args:
78+ log_widget (log_utility.Log): a Log to display the progress bar in. If None, uses print()
79+ """
7180 self .log_widget = log_widget
7281
73- def download_weights (self ,model_name : str ):
82+ def download_weights (self , model_name : str , model_weights_filename : str ):
7483 """
75- Downloads a specific pretrained model.
76- This code is adapted from DeepLabCut with permission from MWMathis.
84+ Downloads a specific pretrained model.
85+ This code is adapted from DeepLabCut with permission from MWMathis.
7786
78- Args:
79- model_name (str): name of the model to download
80- """
87+ Args:
88+ model_name (str): name of the model to download
89+ model_weights_filename (str): name of the .pth file expected for the model
90+ """
8191 import json
8292 import tarfile
8393 import urllib .request
@@ -94,6 +104,17 @@ def show_progress(count, block_size, total_size):
94104 json_path = os .path .join (
95105 pretrained_folder_path , "pretrained_model_urls.json"
96106 )
107+
108+ check_path = os .path .join (
109+ pretrained_folder_path , model_weights_filename
110+ )
111+ if os .path .exists (check_path ):
112+ message = f"Weight file { model_weights_filename } already exists, skipping download step"
113+ if self .log_widget is not None :
114+ self .log_widget .print_and_log (message , printing = False )
115+ print (message )
116+ return
117+
97118 with open (json_path ) as f :
98119 neturls = json .load (f )
99120 if model_name in neturls .keys ():
@@ -107,9 +128,16 @@ def show_progress(count, block_size, total_size):
107128 pbar = tqdm (unit = "B" , total = total_size , position = 0 )
108129 else :
109130 self .log_widget .print_and_log (start_message )
110- pbar = tqdm (unit = "B" , total = total_size , position = 0 , file = self .log_widget )
131+ pbar = tqdm (
132+ unit = "B" ,
133+ total = total_size ,
134+ position = 0 ,
135+ file = self .log_widget ,
136+ )
111137
112- filename , _ = urllib .request .urlretrieve (url , reporthook = show_progress )
138+ filename , _ = urllib .request .urlretrieve (
139+ url , reporthook = show_progress
140+ )
113141 with tarfile .open (filename , mode = "r:gz" ) as tar :
114142 tar .extractall (pretrained_folder_path )
115143 else :
@@ -121,10 +149,12 @@ def show_progress(count, block_size, total_size):
121149class LogSignal (WorkerBaseSignals ):
122150 """Signal to send messages to be logged from another thread.
123151
124- Separate from Worker instances as indicated `here`_"""
152+ Separate from Worker instances as indicated `here`_""" # TODO link ?
125153
126154 log_signal = Signal (str )
127155 """qtpy.QtCore.Signal: signal to be sent when some text should be logged"""
156+ warn_signal = Signal (str )
157+ """qtpy.QtCore.Signal: signal to be sent when some warning should be emitted in main thread"""
128158
129159 # Should not be an instance variable but a class variable, not defined in __init__, see
130160 # https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect
@@ -185,6 +215,7 @@ def __init__(
185215 super ().__init__ (self .inference )
186216 self ._signals = LogSignal () # add custom signals
187217 self .log_signal = self ._signals .log_signal
218+ self .warn_signal = self ._signals .warn_signal
188219 ###########################################
189220 ###########################################
190221 self .device = device
@@ -204,7 +235,6 @@ def __init__(
204235 self .downloader = WeightsDownloader ()
205236 """Download utility"""
206237
207-
208238 @staticmethod
209239 def create_inference_dict (images_filepaths ):
210240 """Create a dict for MONAI with "image" keys with all image paths in :py:attr:`~self.images_filepaths`
@@ -225,6 +255,10 @@ def log(self, text):
225255 """
226256 self .log_signal .emit (text )
227257
258+ def warn (self , warning ):
259+ """Sends a warning to main thread"""
260+ self .warn_signal .emit (warning )
261+
228262 def log_parameters (self ):
229263
230264 self .log ("-" * 20 )
@@ -297,7 +331,7 @@ def inference(self):
297331 sys = platform .system ()
298332 print (f"OS is { sys } " )
299333 if sys == "Darwin" :
300- torch .set_num_threads (1 ) # required for threading on macOS ?
334+ torch .set_num_threads (1 ) # required for threading on macOS ?
301335 self .log ("Number of threads has been set to 1 for macOS" )
302336
303337 images_dict = self .create_inference_dict (self .images_filepaths )
@@ -323,7 +357,11 @@ def inference(self):
323357 model = self .model_dict ["class" ].get_net ()
324358 if self .model_dict ["name" ] == "SegResNet" :
325359 model = self .model_dict ["class" ].get_net ()(
326- input_image_size = [dims , dims , dims ], # TODO FIX ! find a better way & remove model-specific code
360+ input_image_size = [
361+ dims ,
362+ dims ,
363+ dims ,
364+ ], # TODO FIX ! find a better way & remove model-specific code
327365 out_channels = 1 ,
328366 # dropout_prob=0.3,
329367 )
@@ -372,8 +410,13 @@ def inference(self):
372410 if self .weights_dict ["custom" ]:
373411 weights = self .weights_dict ["path" ]
374412 else :
375- self .downloader .download_weights (self .model_dict ["name" ])
376- weights = os .path .join (WEIGHTS_DIR , self .model_dict ["class" ].get_weights_file ())
413+ self .downloader .download_weights (
414+ self .model_dict ["name" ],
415+ self .model_dict ["class" ].get_weights_file (),
416+ )
417+ weights = os .path .join (
418+ WEIGHTS_DIR , self .model_dict ["class" ].get_weights_file ()
419+ )
377420
378421 model .load_state_dict (
379422 torch .load (
@@ -611,7 +654,10 @@ def __init__(
611654 super ().__init__ (self .train )
612655 self ._signals = LogSignal ()
613656 self .log_signal = self ._signals .log_signal
657+ self .warn_signal = self ._signals .warn_signal
614658
659+ self ._weight_error = False
660+ #############################################
615661 self .device = device
616662 self .model_dict = model_dict
617663 self .weights_path = weights_path
@@ -633,7 +679,7 @@ def __init__(
633679
634680 self .train_files = []
635681 self .val_files = []
636-
682+ #######################################
637683 self .downloader = WeightsDownloader ()
638684
639685 def set_download_log (self , widget ):
@@ -647,6 +693,10 @@ def log(self, text):
647693 """
648694 self .log_signal .emit (text )
649695
696+ def warn (self , warning ):
697+ """Sends a warning to main thread"""
698+ self .warn_signal .emit (warning )
699+
650700 def log_parameters (self ):
651701
652702 self .log ("-" * 20 )
@@ -690,6 +740,13 @@ def log_parameters(self):
690740
691741 if self .weights_path is not None :
692742 self .log (f"Using weights from : { self .weights_path } " )
743+ if self ._weight_error :
744+ self .log (
745+ ">>>>>>>>>>>>>>>>>\n "
746+ "WARNING:\n Chosen weights were incompatible with the model,\n "
747+ "the model will be trained from random weights\n "
748+ "<<<<<<<<<<<<<<<<<\n "
749+ )
693750
694751 # self.log("\n")
695752 self .log ("-" * 20 )
@@ -904,18 +961,27 @@ def train(self):
904961 if self .weights_path is not None :
905962 if self .weights_path == "use_pretrained" :
906963 weights_file = model_class .get_weights_file ()
907- self .downloader .download_weights (model_name )
964+ self .downloader .download_weights (model_name , weights_file )
908965 weights = os .path .join (WEIGHTS_DIR , weights_file )
909966 self .weights_path = weights
910967 else :
911968 weights = os .path .join (self .weights_path )
912969
913- model .load_state_dict (
914- torch .load (
915- weights ,
916- map_location = self .device ,
970+ try :
971+ model .load_state_dict (
972+ torch .load (
973+ weights ,
974+ map_location = self .device ,
975+ )
917976 )
918- )
977+ except RuntimeError :
978+ warn = (
979+ "WARNING:\n It seems the weights were incompatible with the model,\n "
980+ "the model will be trained from random weights"
981+ )
982+ self .log (warn )
983+ self .warn (warn )
984+ self ._weight_error = True
919985
920986 if self .device .type == "cuda" :
921987 self .log ("\n Using GPU :" )
0 commit comments