Skip to content

Commit 99c436b

Browse files
committed
WIP download weights
1 parent 7f46f2f commit 99c436b

File tree

6 files changed

+85
-24
lines changed

6 files changed

+85
-24
lines changed

docs/res/guides/custom_model_template.rst

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ To add a custom model, you will need a **.py** file with the following structure
1717

1818
def get_weights_file():
1919
return "weights_file.pth" # name of the weights file for the model,
20-
# which should be in *napari_cellseg3d/models/saved_weights*
20+
# which should be in *napari_cellseg3d/models/pretrained*
2121

2222

2323
def get_output(model, input):
@@ -35,5 +35,3 @@ To add a custom model, you will need a **.py** file with the following structure
3535
def ModelClass(x1,x2...):
3636
# your Pytorch model here...
3737
return results # should return as [C, N, D,H,W]
38-
39-

napari_cellseg3d/models/TRAILMAP_MS.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import torch
22
from torch import nn
3+
from napari_cellseg3d import utils
4+
import os
35

6+
modelname = "TRAILMAP_MS"
7+
target_dir = os.path.join("models","pretrained")
48

59
def get_weights_file():
6-
# return "TMP_TEST_40e.pth"
7-
return "TRAILMAP_DFl_best.pth"
10+
utils.DownloadModel(model, train_dir)
11+
return "TRAILMAP_MS_best_metric_epoch_26.pth" #model additionally trained on Mathis/Wyss mesoSPIM data
812

913

1014
def get_net():

napari_cellseg3d/models/model_SegResNet.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from monai.networks.nets import SegResNetVAE
2+
from napari_cellseg3d import utils
3+
import os
24

5+
modelname = "SegResNet"
6+
target_dir = os.path.join("models","pretrained")
37

48
def get_net():
59
return SegResNetVAE
610

711

812
def get_weights_file():
13+
utils.DownloadModel(modelname, target_dir)
914
return "SegResNet.pth"
1015

1116

napari_cellseg3d/models/model_TRAILMAP.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from napari_cellseg3d.models.unet.model import UNet3D
2+
from napari_cellseg3d import utils
3+
import os
24

5+
modelname = "TRAILMAP"
6+
target_dir = os.path.join("models","pretrained")
37

48
def get_weights_file():
5-
# return "TMP_TEST_40e.pth"
6-
return "trailmaptorchpretrained.pth"
9+
utils.DownloadModel(model, train_dir)
10+
return "TRAILMAP_PyTorch.pth" #original model form Luo lab, transfered to pytroch
711

812

913
def get_net():

napari_cellseg3d/models/model_VNet.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
from monai.inferers import sliding_window_inference
22
from monai.networks.nets import VNet
3+
from napari_cellseg3d import utils
4+
import os
35

6+
modelname = "VNet"
7+
target_dir = os.path.join("models","pretrained")
48

59
def get_net():
610
return VNet()
711

812

913
def get_weights_file():
10-
# return "dice_VNet.pth"
14+
utils.DownloadModel(model, train_dir)
1115
return "VNet_40e.pth"
1216

1317

napari_cellseg3d/utils.py

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,29 +12,14 @@
1212
from skimage.filters import gaussian
1313
from tifffile import imread as tfl_imread
1414
from tqdm import tqdm
15+
import importlib.util
1516

1617
"""
1718
utils.py
1819
====================================
1920
Definitions of utility functions and variables
2021
"""
2122

22-
##################
23-
##################
24-
# dev util
25-
def ENABLE_TEST_MODE():
26-
path = Path(os.path.expanduser("~"))
27-
# print(path)
28-
print("TEST MODE ENABLED, DEV ONLY")
29-
if path == Path("C:/Users/Cyril"):
30-
return True
31-
return False
32-
33-
34-
##################
35-
##################
36-
37-
3823
def normalize_x(image):
3924
"""Normalizes the values of an image array to be between [-1;1] rather than [0;255]
4025
@@ -996,3 +981,64 @@ def merge_imgs(imgs, original_image_shape):
996981

997982
print(merged_imgs.shape)
998983
return merged_imgs
984+
985+
986+
def read_plainconfig(configname):
987+
"""
988+
This code is adapted from DeepLabCut with permission from MWMathis
989+
"""
990+
if not os.path.exists(configname):
991+
raise FileNotFoundError(
992+
f"Config {configname} is not found. Please make sure that the file exists."
993+
)
994+
with open(configname) as file:
995+
return YAML().load(file)
996+
997+
def DownloadModel(modelname, target_dir):
998+
"""
999+
Downloads a specific pretained model.
1000+
This code is adapted from DeepLabCut with permission from MWMathis
1001+
"""
1002+
import urllib.request
1003+
import tarfile
1004+
from tqdm import tqdm
1005+
1006+
def show_progress(count, block_size, total_size):
1007+
pbar.update(block_size)
1008+
1009+
def tarfilenamecutting(tarf):
1010+
"""' auxfun to extract folder path
1011+
ie. /xyz-trainsetxyshufflez/
1012+
"""
1013+
for memberid, member in enumerate(tarf.getmembers()):
1014+
if memberid == 0:
1015+
parent = str(member.path)
1016+
l = len(parent) + 1
1017+
if member.path.startswith(parent):
1018+
member.path = member.path[l:]
1019+
yield member
1020+
#TODO: fix error in line 1021;
1021+
cellseg3d_path = os.path.split(importlib.util.find_spec("napari-cellseg3d").origin)[0]
1022+
neturls = read_plainconfig(os.path.join(cellseg3d_path,"models","pretrained","pretrained_model_urls.yaml",))
1023+
1024+
if modelname in neturls.keys():
1025+
url = neturls[modelname]
1026+
response = urllib.request.urlopen(url)
1027+
print(
1028+
"Downloading the model from the M.W. Mathis Lab server {}....".format(
1029+
url
1030+
)
1031+
)
1032+
total_size = int(response.getheader("Content-Length"))
1033+
pbar = tqdm(unit="B", total=total_size, position=0)
1034+
filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress)
1035+
with tarfile.open(filename, mode="r:gz") as tar:
1036+
tar.extractall(target_dir, members=tarfilenamecutting(tar))
1037+
else:
1038+
models = [
1039+
fn
1040+
for fn in neturls.keys()
1041+
if "VNet_" not in fn and "SegResNet" not in fn and "TRAILMAP_" not in fn
1042+
]
1043+
print("Model does not exist: ", modelname)
1044+
#print("Pick one of the following: ", models)

0 commit comments

Comments
 (0)