Skip to content

Commit 8789bdf

Browse files
committed
Download test and fixed old desc
1 parent 368e473 commit 8789bdf

File tree

3 files changed

+26
-3
lines changed

3 files changed

+26
-3
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from napari_cellseg3d.code_models.model_workers import WeightsDownloader, WEIGHTS_DIR
2+
from napari_cellseg3d.config import ModelInfo
3+
4+
5+
def test_weight_download():
6+
7+
info = ModelInfo()
8+
9+
downloader = WeightsDownloader()
10+
11+
downloader.download_weights(
12+
info.name,
13+
info.get_model().get_weights_file()
14+
)
15+
result_path = WEIGHTS_DIR / str(info.get_model().get_weights_file())
16+
17+
assert result_path.is_file()
18+
19+

napari_cellseg3d/code_models/model_workers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def show_progress(count, block_size, total_size):
119119
url = neturls[model_name]
120120
response = urllib.request.urlopen(url)
121121

122-
start_message = f"Downloading the model from the M.W. Mathis Lab server {url}...."
122+
start_message = f"Downloading the model from HuggingFace {url}...."
123123
total_size = int(response.getheader("Content-Length"))
124124
if self.log_widget is None:
125125
logger.info(start_message)
@@ -142,7 +142,11 @@ def is_within_directory(directory, target):
142142
abs_directory = Path(directory).resolve()
143143
abs_target = Path(target).resolve()
144144
# prefix = os.path.commonprefix([abs_directory, abs_target])
145-
return abs_target in abs_directory.parents
145+
logger.debug(abs_directory)
146+
logger.debug(abs_target)
147+
logger.debug(abs_directory.parents)
148+
149+
return abs_directory in abs_target.parents
146150

147151
def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
148152

napari_cellseg3d/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
# TODO(cyril) add JSON load/save
2424

2525
MODEL_LIST = {
26-
"VNet": VNet,
2726
"SegResNet": SegResNet,
27+
"VNet": VNet,
2828
# "TRAILMAP": TRAILMAP,
2929
"TRAILMAP_MS": TRAILMAP_MS,
3030
"SwinUNetR": SwinUNetR,

0 commit comments

Comments
 (0)