Skip to content

Commit e26285c

Browse files
committed
Skip download + weights error handling
- Skips download if the weights file is already present - Skips state_dict loading step in training if the weights are found to be incompatible
1 parent 1709528 commit e26285c

File tree

7 files changed

+73
-32
lines changed

7 files changed

+73
-32
lines changed

napari_cellseg3d/log_utility.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@ def write(self, message):
2828
try:
2929
if not hasattr(self, "flag"):
3030
self.flag = False
31-
message = message.replace('\r', '').rstrip()
31+
message = message.replace("\r", "").rstrip()
3232
if message:
3333
method = "replace_last_line" if self.flag else "append"
34-
QtCore.QMetaObject.invokeMethod(self,
35-
method,
36-
QtCore.Qt.QueuedConnection,
37-
QtCore.Q_ARG(str, message))
34+
QtCore.QMetaObject.invokeMethod(
35+
self,
36+
method,
37+
QtCore.Qt.QueuedConnection,
38+
QtCore.Q_ARG(str, message),
39+
)
3840
self.flag = True
3941
else:
4042
self.flag = False

napari_cellseg3d/model_workers.py

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44
import importlib.util
55
from typing import Optional
6+
import warnings
67

78
import numpy as np
89
from tifffile import imwrite
@@ -65,19 +66,28 @@
6566
Path("/models/pretrained")
6667
)
6768

69+
6870
class WeightsDownloader:
71+
"""A utility class the downloads the weights of a model when needed."""
72+
73+
def __init__(self, log_widget: Optional[log_utility.Log] = None):
74+
"""
75+
Creates a WeightsDownloader, optionally with a log widget to display the progress.
6976
70-
def __init__(self, log_widget: Optional[log_utility.Log]= None):
77+
Args:
78+
log_widget (log_utility.Log): a Log to display the progress bar in. If None, uses print()
79+
"""
7180
self.log_widget = log_widget
7281

73-
def download_weights(self,model_name: str):
82+
def download_weights(self, model_name: str, model_weights_filename: str):
7483
"""
75-
Downloads a specific pretrained model.
76-
This code is adapted from DeepLabCut with permission from MWMathis.
84+
Downloads a specific pretrained model.
85+
This code is adapted from DeepLabCut with permission from MWMathis.
7786
78-
Args:
79-
model_name (str): name of the model to download
80-
"""
87+
Args:
88+
model_name (str): name of the model to download
89+
model_weights_filename (str): name of the .pth file expected for the model
90+
"""
8191
import json
8292
import tarfile
8393
import urllib.request
@@ -94,6 +104,17 @@ def show_progress(count, block_size, total_size):
94104
json_path = os.path.join(
95105
pretrained_folder_path, "pretrained_model_urls.json"
96106
)
107+
108+
check_path = os.path.join(
109+
pretrained_folder_path, model_weights_filename
110+
)
111+
if os.path.exists(check_path):
112+
message = f"Weight file {model_weights_filename} already exists, skipping download step"
113+
if self.log_widget is not None:
114+
self.log_widget.print_and_log(message, printing=False)
115+
print(message)
116+
return
117+
97118
with open(json_path) as f:
98119
neturls = json.load(f)
99120
if model_name in neturls.keys():
@@ -107,9 +128,16 @@ def show_progress(count, block_size, total_size):
107128
pbar = tqdm(unit="B", total=total_size, position=0)
108129
else:
109130
self.log_widget.print_and_log(start_message)
110-
pbar = tqdm(unit="B", total=total_size, position=0, file=self.log_widget)
131+
pbar = tqdm(
132+
unit="B",
133+
total=total_size,
134+
position=0,
135+
file=self.log_widget,
136+
)
111137

112-
filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress)
138+
filename, _ = urllib.request.urlretrieve(
139+
url, reporthook=show_progress
140+
)
113141
with tarfile.open(filename, mode="r:gz") as tar:
114142
tar.extractall(pretrained_folder_path)
115143
else:
@@ -204,7 +232,6 @@ def __init__(
204232
self.downloader = WeightsDownloader()
205233
"""Download utility"""
206234

207-
208235
@staticmethod
209236
def create_inference_dict(images_filepaths):
210237
"""Create a dict for MONAI with "image" keys with all image paths in :py:attr:`~self.images_filepaths`
@@ -297,7 +324,7 @@ def inference(self):
297324
sys = platform.system()
298325
print(f"OS is {sys}")
299326
if sys == "Darwin":
300-
torch.set_num_threads(1) # required for threading on macOS ?
327+
torch.set_num_threads(1) # required for threading on macOS ?
301328
self.log("Number of threads has been set to 1 for macOS")
302329

303330
images_dict = self.create_inference_dict(self.images_filepaths)
@@ -323,7 +350,11 @@ def inference(self):
323350
model = self.model_dict["class"].get_net()
324351
if self.model_dict["name"] == "SegResNet":
325352
model = self.model_dict["class"].get_net()(
326-
input_image_size=[dims, dims, dims], # TODO FIX ! find a better way & remove model-specific code
353+
input_image_size=[
354+
dims,
355+
dims,
356+
dims,
357+
], # TODO FIX ! find a better way & remove model-specific code
327358
out_channels=1,
328359
# dropout_prob=0.3,
329360
)
@@ -372,8 +403,13 @@ def inference(self):
372403
if self.weights_dict["custom"]:
373404
weights = self.weights_dict["path"]
374405
else:
375-
self.downloader.download_weights(self.model_dict["name"])
376-
weights = os.path.join(WEIGHTS_DIR, self.model_dict["class"].get_weights_file())
406+
self.downloader.download_weights(
407+
self.model_dict["name"],
408+
self.model_dict["class"].get_weights_file(),
409+
)
410+
weights = os.path.join(
411+
WEIGHTS_DIR, self.model_dict["class"].get_weights_file()
412+
)
377413

378414
model.load_state_dict(
379415
torch.load(
@@ -904,18 +940,26 @@ def train(self):
904940
if self.weights_path is not None:
905941
if self.weights_path == "use_pretrained":
906942
weights_file = model_class.get_weights_file()
907-
self.downloader.download_weights(model_name)
943+
self.downloader.download_weights(model_name, weights_file)
908944
weights = os.path.join(WEIGHTS_DIR, weights_file)
909945
self.weights_path = weights
910946
else:
911947
weights = os.path.join(self.weights_path)
912948

913-
model.load_state_dict(
914-
torch.load(
915-
weights,
916-
map_location=self.device,
949+
try:
950+
model.load_state_dict(
951+
torch.load(
952+
weights,
953+
map_location=self.device,
954+
)
917955
)
918-
)
956+
except RuntimeError:
957+
warn = (
958+
"WARNING:\nIt seems the weights were incompatible with the model,\n"
959+
"the model will be trained from random weights"
960+
)
961+
self.log(warn)
962+
warnings.warn(warn)
919963

920964
if self.device.type == "cuda":
921965
self.log("\nUsing GPU :")

napari_cellseg3d/models/model_SegResNet.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from monai.networks.nets import SegResNetVAE
22

33

4-
54
def get_net():
65
return SegResNetVAE
76

napari_cellseg3d/models/model_TRAILMAP.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
from torch import nn
33

44

5-
65
def get_weights_file():
76
# model additionally trained on Mathis/Wyss mesoSPIM data
87
return "TRAILMAP.pth"
98
# FIXME currently incorrect, find good weights from TRAILMAP_test and upload them
109

10+
1111
def get_net():
1212
return TRAILMAP(1, 1)
1313

@@ -120,4 +120,3 @@ def outBlock(self, in_ch, out_ch, kernel_size, padding="same"):
120120
# nn.BatchNorm3d(out_ch),
121121
)
122122
return out
123-

napari_cellseg3d/models/model_VNet.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from monai.networks.nets import VNet
33

44

5-
65
def get_net():
76
return VNet()
87

napari_cellseg3d/plugin_model_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,7 @@ def make_csv(self):
994994
if len(self.loss_values) == 0 or self.loss_values is None:
995995
warnings.warn("No loss values to add to csv !")
996996
return
997-
997+
998998
self.df = pd.DataFrame(
999999
{
10001000
"epoch": size_column,

napari_cellseg3d/utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -978,5 +978,3 @@ def merge_imgs(imgs, original_image_shape):
978978

979979
print(merged_imgs.shape)
980980
return merged_imgs
981-
982-

0 commit comments

Comments
 (0)