Skip to content

Commit 96a62bd

Browse files
Merge pull request #340 from MannLabs/patch_337
[FEATURE] add backup repo for downloading cellpose models
2 parents e6b6087 + 0677ce8 commit 96a62bd

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed

src/scportrait/pipeline/segmentation/workflows/_cellpose.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
ShardedSegmentation,
1919
)
2020
from scportrait.pipeline.segmentation.workflows._base_segmentation_workflow import _BaseSegmentation
21+
from scportrait.pipeline.segmentation.workflows._model_caches import _download_model
2122

2223

2324
class _CellposeSegmentation(_BaseSegmentation):
@@ -53,6 +54,14 @@ def _read_cellpose_model(self, modeltype: str, name: str, gpu: str, device) -> m
5354
5455
"""
5556
if modeltype == "pretrained":
57+
try:
58+
_download_model(name)
59+
60+
except FileNotFoundError as e:
61+
raise FileNotFoundError(
62+
f"Could not download the requested Cellpose model '{name}'. "
63+
"Please check the model name or ensure that the Cellpose model server is available."
64+
) from e
5665
model = models.Cellpose(model_type=name, gpu=gpu, device=device)
5766
elif modeltype == "custom":
5867
if not Path(name).exists():
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import os
2+
from urllib.error import HTTPError
3+
from urllib.parse import quote
4+
5+
from cellpose import utils
6+
from cellpose.models import MODEL_DIR, model_path
7+
8+
ZENODO_RECORD_ID = "17564109"
9+
10+
11+
def _make_zenodo_download_link(record_id: str, filename: str) -> str:
12+
"""
13+
Construct a direct download URL for a file stored in a Zenodo record.
14+
15+
Parameters
16+
----------
17+
record_id : str
18+
The Zenodo record identifier (e.g., "1234567").
19+
filename : str
20+
The exact filename stored in the Zenodo record (case sensitive).
21+
22+
Returns
23+
-------
24+
str
25+
A direct HTTPS download URL suitable for urllib / requests / wget.
26+
"""
27+
return f"https://zenodo.org/records/{record_id}/files/{quote(filename)}?download=1"
28+
29+
30+
def _scportrait_cache_model_path(basename: str) -> None:
31+
"""Download a model from a public Nextcloud share into Cellpose's model cache if missing."""
32+
MODEL_DIR.mkdir(parents=True, exist_ok=True)
33+
34+
url = _make_zenodo_download_link(
35+
record_id=ZENODO_RECORD_ID,
36+
filename=basename,
37+
)
38+
cached_file = MODEL_DIR / basename
39+
40+
if not cached_file.exists():
41+
print(f'Downloading: "{url}" → {cached_file}')
42+
utils.download_url_to_file(url, os.fspath(cached_file), progress=True)
43+
44+
return None
45+
46+
47+
def _model_path(model_type: str, model_index: int = 0) -> None:
48+
"""Return local path to a Cellpose model (downloading if needed)."""
49+
torch_str = "torch"
50+
if model_type in ("cyto", "cyto2", "nuclei"):
51+
basename = f"{model_type}{torch_str}_{model_index}"
52+
else:
53+
basename = model_type
54+
return _scportrait_cache_model_path(basename)
55+
56+
57+
def _size_model_path(model_type: str) -> None:
58+
"""Return local path to the size model (downloading if needed)."""
59+
torch_str = "torch"
60+
61+
if model_type in ("cyto", "nuclei", "cyto2", "cyto3"):
62+
if model_type == "cyto3":
63+
basename = f"size_{model_type}.npy"
64+
else:
65+
basename = f"size_{model_type}{torch_str}_0.npy"
66+
return _scportrait_cache_model_path(basename)
67+
else:
68+
# nothing to do
69+
return None
70+
71+
72+
def _download_model(name: str):
73+
try:
74+
# Try default cellpose download
75+
model_path(name)
76+
except HTTPError:
77+
print("Cellpose model server appears to be down. Trying scPortrait backup cache...")
78+
79+
# Try scPortrait backup cache
80+
_model_path(name)
81+
_size_model_path(name)
82+
print("Cellpose model and size file downloaded from scPortrait cache.")

0 commit comments

Comments
 (0)