Skip to content

Commit 04e0af2

Browse files
committed
Convert the URL file to json to save a dependency
1 parent 84486d9 commit 04e0af2

File tree

7 files changed

+17
-30
lines changed

7 files changed

+17
-30
lines changed

napari_cellseg3d/models/TRAILMAP_MS.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
target_dir = os.path.join("models","pretrained")
88

99
def get_weights_file():
10-
utils.DownloadModel(modelname, target_dir)
10+
utils.download_model(modelname, target_dir)
1111
return "TRAILMAP_MS_best_metric_epoch_26.pth" #model additionally trained on Mathis/Wyss mesoSPIM data
1212

1313

napari_cellseg3d/models/model_SegResNet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def get_net():
1010

1111

1212
def get_weights_file():
13-
utils.DownloadModel(modelname, target_dir)
13+
utils.download_model(modelname, target_dir)
1414
return "SegResNet.pth"
1515

1616

napari_cellseg3d/models/model_TRAILMAP.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
target_dir = os.path.join("models","pretrained")
77

88
def get_weights_file():
9-
uutils.DownloadModel(modelname, target_dir)
9+
utils.download_model(modelname, target_dir)
1010
return "TRAILMAP_PyTorch.pth" #original model from Liqun Luo lab, transfered to pytorch
1111

1212

napari_cellseg3d/models/model_VNet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def get_net():
1111

1212

1313
def get_weights_file():
14-
utils.DownloadModel(modelname, target_dir)
14+
utils.download_model(modelname, target_dir)
1515
return "VNet_40e.pth"
1616

1717

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"TRAILMAP_MS": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP_MS.tar.gz",
3+
"TRAILMAP": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/TRAILMAP.tar.gz",
4+
"SegResNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/SegResNet.tar.gz",
5+
"VNet": "http://deeplabcut.rowland.harvard.edu/cellseg3dmodels/VNet.tar.gz"
6+
}

napari_cellseg3d/models/pretrained/pretrained_model_urls.yaml

Lines changed: 0 additions & 7 deletions
This file was deleted.

napari_cellseg3d/utils.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -983,25 +983,14 @@ def merge_imgs(imgs, original_image_shape):
983983
return merged_imgs
984984

985985

986-
def read_plainconfig(configname):
987-
"""
988-
This code is adapted from DeepLabCut with permission from MWMathis
989-
"""
990-
if not os.path.exists(configname):
991-
raise FileNotFoundError(
992-
f"Config {configname} is not found. Please make sure that the file exists."
993-
)
994-
with open(configname) as file:
995-
return YAML().load(file)
996-
997-
def DownloadModel(modelname, target_dir):
986+
def download_model(modelname, target_dir):
998987
"""
999988
Downloads a specific pretained model.
1000989
This code is adapted from DeepLabCut with permission from MWMathis
1001990
"""
991+
import json
1002992
import urllib.request
1003993
import tarfile
1004-
from tqdm import tqdm
1005994

1006995
def show_progress(count, block_size, total_size):
1007996
pbar.update(block_size)
@@ -1017,18 +1006,17 @@ def tarfilenamecutting(tarf):
10171006
if member.path.startswith(parent):
10181007
member.path = member.path[l:]
10191008
yield member
1009+
10201010
#TODO: fix error in line 1021;
10211011
cellseg3d_path = os.path.split(importlib.util.find_spec("napari-cellseg3d").origin)[0]
1022-
neturls = read_plainconfig(os.path.join(cellseg3d_path,"models","pretrained","pretrained_model_urls.yaml",))
1012+
json_path = os.path.join(cellseg3d_path, "models", "pretrained", "pretrained_model_urls.json")
1013+
with open(json_path) as f:
1014+
neturls = json.load(f)
10231015

10241016
if modelname in neturls.keys():
10251017
url = neturls[modelname]
10261018
response = urllib.request.urlopen(url)
1027-
print(
1028-
"Downloading the model from the M.W. Mathis Lab server {}....".format(
1029-
url
1030-
)
1031-
)
1019+
print(f"Downloading the model from the M.W. Mathis Lab server {url}....")
10321020
total_size = int(response.getheader("Content-Length"))
10331021
pbar = tqdm(unit="B", total=total_size, position=0)
10341022
filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress)

0 commit comments

Comments
 (0)