|
| 1 | +import os |
| 2 | +from urllib.error import HTTPError |
| 3 | +from urllib.parse import quote |
| 4 | + |
| 5 | +from cellpose import utils |
| 6 | +from cellpose.models import MODEL_DIR, model_path |
| 7 | + |
| 8 | +ZENODO_RECORD_ID = "17564109" |
| 9 | + |
| 10 | + |
| 11 | +def _make_zenodo_download_link(record_id: str, filename: str) -> str: |
| 12 | + """ |
| 13 | + Construct a direct download URL for a file stored in a Zenodo record. |
| 14 | +
|
| 15 | + Parameters |
| 16 | + ---------- |
| 17 | + record_id : str |
| 18 | + The Zenodo record identifier (e.g., "1234567"). |
| 19 | + filename : str |
| 20 | + The exact filename stored in the Zenodo record (case sensitive). |
| 21 | +
|
| 22 | + Returns |
| 23 | + ------- |
| 24 | + str |
| 25 | + A direct HTTPS download URL suitable for urllib / requests / wget. |
| 26 | + """ |
| 27 | + return f"https://zenodo.org/records/{record_id}/files/{quote(filename)}?download=1" |
| 28 | + |
| 29 | + |
| 30 | +def _scportrait_cache_model_path(basename: str) -> None: |
| 31 | + """Download a model from a public Nextcloud share into Cellpose's model cache if missing.""" |
| 32 | + MODEL_DIR.mkdir(parents=True, exist_ok=True) |
| 33 | + |
| 34 | + url = _make_zenodo_download_link( |
| 35 | + record_id=ZENODO_RECORD_ID, |
| 36 | + filename=basename, |
| 37 | + ) |
| 38 | + cached_file = MODEL_DIR / basename |
| 39 | + |
| 40 | + if not cached_file.exists(): |
| 41 | + print(f'Downloading: "{url}" → {cached_file}') |
| 42 | + utils.download_url_to_file(url, os.fspath(cached_file), progress=True) |
| 43 | + |
| 44 | + return None |
| 45 | + |
| 46 | + |
| 47 | +def _model_path(model_type: str, model_index: int = 0) -> None: |
| 48 | + """Return local path to a Cellpose model (downloading if needed).""" |
| 49 | + torch_str = "torch" |
| 50 | + if model_type in ("cyto", "cyto2", "nuclei"): |
| 51 | + basename = f"{model_type}{torch_str}_{model_index}" |
| 52 | + else: |
| 53 | + basename = model_type |
| 54 | + return _scportrait_cache_model_path(basename) |
| 55 | + |
| 56 | + |
| 57 | +def _size_model_path(model_type: str) -> None: |
| 58 | + """Return local path to the size model (downloading if needed).""" |
| 59 | + torch_str = "torch" |
| 60 | + |
| 61 | + if model_type in ("cyto", "nuclei", "cyto2", "cyto3"): |
| 62 | + if model_type == "cyto3": |
| 63 | + basename = f"size_{model_type}.npy" |
| 64 | + else: |
| 65 | + basename = f"size_{model_type}{torch_str}_0.npy" |
| 66 | + return _scportrait_cache_model_path(basename) |
| 67 | + else: |
| 68 | + # nothing to do |
| 69 | + return None |
| 70 | + |
| 71 | + |
| 72 | +def _download_model(name: str): |
| 73 | + try: |
| 74 | + # Try default cellpose download |
| 75 | + model_path(name) |
| 76 | + except HTTPError: |
| 77 | + print("Cellpose model server appears to be down. Trying scPortrait backup cache...") |
| 78 | + |
| 79 | + # Try scPortrait backup cache |
| 80 | + _model_path(name) |
| 81 | + _size_model_path(name) |
| 82 | + print("Cellpose model and size file downloaded from scPortrait cache.") |
0 commit comments