33from pathlib import Path
44import importlib .util
55from typing import Optional
6+ import warnings
67
78import numpy as np
89from tifffile import imwrite
6566 Path ("/models/pretrained" )
6667)
6768
69+
6870class WeightsDownloader :
71+ """A utility class the downloads the weights of a model when needed."""
72+
73+ def __init__ (self , log_widget : Optional [log_utility .Log ] = None ):
74+ """
75+ Creates a WeightsDownloader, optionally with a log widget to display the progress.
6976
70- def __init__ (self , log_widget : Optional [log_utility .Log ]= None ):
77+ Args:
78+ log_widget (log_utility.Log): a Log to display the progress bar in. If None, uses print()
79+ """
7180 self .log_widget = log_widget
7281
73- def download_weights (self ,model_name : str ):
82+ def download_weights (self , model_name : str , model_weights_filename : str ):
7483 """
75- Downloads a specific pretrained model.
76- This code is adapted from DeepLabCut with permission from MWMathis.
84+ Downloads a specific pretrained model.
85+ This code is adapted from DeepLabCut with permission from MWMathis.
7786
78- Args:
79- model_name (str): name of the model to download
80- """
87+ Args:
88+ model_name (str): name of the model to download
89+ model_weights_filename (str): name of the .pth file expected for the model
90+ """
8191 import json
8292 import tarfile
8393 import urllib .request
@@ -94,6 +104,17 @@ def show_progress(count, block_size, total_size):
94104 json_path = os .path .join (
95105 pretrained_folder_path , "pretrained_model_urls.json"
96106 )
107+
108+ check_path = os .path .join (
109+ pretrained_folder_path , model_weights_filename
110+ )
111+ if os .path .exists (check_path ):
112+ message = f"Weight file { model_weights_filename } already exists, skipping download step"
113+ if self .log_widget is not None :
114+ self .log_widget .print_and_log (message , printing = False )
115+ print (message )
116+ return
117+
97118 with open (json_path ) as f :
98119 neturls = json .load (f )
99120 if model_name in neturls .keys ():
@@ -107,9 +128,16 @@ def show_progress(count, block_size, total_size):
107128 pbar = tqdm (unit = "B" , total = total_size , position = 0 )
108129 else :
109130 self .log_widget .print_and_log (start_message )
110- pbar = tqdm (unit = "B" , total = total_size , position = 0 , file = self .log_widget )
131+ pbar = tqdm (
132+ unit = "B" ,
133+ total = total_size ,
134+ position = 0 ,
135+ file = self .log_widget ,
136+ )
111137
112- filename , _ = urllib .request .urlretrieve (url , reporthook = show_progress )
138+ filename , _ = urllib .request .urlretrieve (
139+ url , reporthook = show_progress
140+ )
113141 with tarfile .open (filename , mode = "r:gz" ) as tar :
114142 tar .extractall (pretrained_folder_path )
115143 else :
@@ -204,7 +232,6 @@ def __init__(
204232 self .downloader = WeightsDownloader ()
205233 """Download utility"""
206234
207-
208235 @staticmethod
209236 def create_inference_dict (images_filepaths ):
210237 """Create a dict for MONAI with "image" keys with all image paths in :py:attr:`~self.images_filepaths`
@@ -297,7 +324,7 @@ def inference(self):
297324 sys = platform .system ()
298325 print (f"OS is { sys } " )
299326 if sys == "Darwin" :
300- torch .set_num_threads (1 ) # required for threading on macOS ?
327+ torch .set_num_threads (1 ) # required for threading on macOS ?
301328 self .log ("Number of threads has been set to 1 for macOS" )
302329
303330 images_dict = self .create_inference_dict (self .images_filepaths )
@@ -323,7 +350,11 @@ def inference(self):
323350 model = self .model_dict ["class" ].get_net ()
324351 if self .model_dict ["name" ] == "SegResNet" :
325352 model = self .model_dict ["class" ].get_net ()(
326- input_image_size = [dims , dims , dims ], # TODO FIX ! find a better way & remove model-specific code
353+ input_image_size = [
354+ dims ,
355+ dims ,
356+ dims ,
357+ ], # TODO FIX ! find a better way & remove model-specific code
327358 out_channels = 1 ,
328359 # dropout_prob=0.3,
329360 )
@@ -372,8 +403,13 @@ def inference(self):
372403 if self .weights_dict ["custom" ]:
373404 weights = self .weights_dict ["path" ]
374405 else :
375- self .downloader .download_weights (self .model_dict ["name" ])
376- weights = os .path .join (WEIGHTS_DIR , self .model_dict ["class" ].get_weights_file ())
406+ self .downloader .download_weights (
407+ self .model_dict ["name" ],
408+ self .model_dict ["class" ].get_weights_file (),
409+ )
410+ weights = os .path .join (
411+ WEIGHTS_DIR , self .model_dict ["class" ].get_weights_file ()
412+ )
377413
378414 model .load_state_dict (
379415 torch .load (
@@ -904,18 +940,26 @@ def train(self):
904940 if self .weights_path is not None :
905941 if self .weights_path == "use_pretrained" :
906942 weights_file = model_class .get_weights_file ()
907- self .downloader .download_weights (model_name )
943+ self .downloader .download_weights (model_name , weights_file )
908944 weights = os .path .join (WEIGHTS_DIR , weights_file )
909945 self .weights_path = weights
910946 else :
911947 weights = os .path .join (self .weights_path )
912948
913- model .load_state_dict (
914- torch .load (
915- weights ,
916- map_location = self .device ,
949+ try :
950+ model .load_state_dict (
951+ torch .load (
952+ weights ,
953+ map_location = self .device ,
954+ )
917955 )
918- )
956+ except RuntimeError :
957+ warn = (
958+ "WARNING:\n It seems the weights were incompatible with the model,\n "
959+ "the model will be trained from random weights"
960+ )
961+ self .log (warn )
962+ warnings .warn (warn )
919963
920964 if self .device .type == "cuda" :
921965 self .log ("\n Using GPU :" )
0 commit comments