Commit b24b2e8

Update segmentation widget
1 parent e6deb38 commit b24b2e8

File tree

4 files changed (+131, -10 lines)


flamingo_tools/file_utils.py

Lines changed: 14 additions & 0 deletions
@@ -1,8 +1,10 @@
+import os
 import warnings
 from typing import Optional, Union

 import imageio.v3 as imageio
 import numpy as np
+import pooch
 import tifffile
 import zarr
 from elf.io import open_file
@@ -13,6 +15,18 @@
 from zarr._storage.store import BaseStore as Store


+def get_cache_dir() -> str:
+    """Get the cache directory of CochleaNet.
+
+    The default cache directory is resolved via pooch.os_cache, e.g. "$HOME/.cache/cochlea-net" on Linux.
+
+    Returns:
+        The cache directory.
+    """
+    cache_dir = os.path.expanduser(pooch.os_cache("cochlea-net"))
+    return cache_dir
+
+
 def _parse_shape(metadata_file):
     depth, height, width = None, None, None

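A note on the new helper: pooch.os_cache resolves to the per-user OS cache directory, so the function can be called directly to check where downloads will land. A minimal usage sketch (the printed path is an example for Linux, not a guarantee):

    # Minimal usage sketch: inspect where the CochleaNet cache resolves.
    from flamingo_tools.file_utils import get_cache_dir

    print(get_cache_dir())  # e.g. /home/<user>/.cache/cochlea-net on Linux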
flamingo_tools/model_utils.py

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+import os
+from typing import Optional, Union
+
+import pooch
+import torch
+from .file_utils import get_cache_dir
+
+
+def _get_default_device():
+    # Check if we're running in CI and use the CPU if we are.
+    # Otherwise the tests may run out of memory on Mac if MPS is used.
+    if os.getenv("GITHUB_ACTIONS") == "true":
+        return "cpu"
+    # Use a CUDA-enabled GPU if it's available.
+    if torch.cuda.is_available():
+        device = "cuda"
+    # As second priority use MPS.
+    # See https://pytorch.org/docs/stable/notes/mps.html for details.
+    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        device = "mps"
+    # Use the CPU as fallback.
+    else:
+        device = "cpu"
+    return device
+
+
+def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
+    """Get the torch device.
+
+    If no device is passed, the default device for your system is used.
+    Otherwise, it is checked whether the passed device is supported.
+
+    Args:
+        device: The input device.
+
+    Returns:
+        The device.
+    """
+    if device is None or device == "auto":
+        device = _get_default_device()
+    else:
+        device_type = device if isinstance(device, str) else device.type
+        if device_type.lower() == "cuda":
+            if not torch.cuda.is_available():
+                raise RuntimeError("PyTorch CUDA backend is not available.")
+        elif device_type.lower() == "mps":
+            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
+                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
+        elif device_type.lower() == "cpu":
+            pass  # The CPU is always available.
+        else:
+            raise RuntimeError(f"Unsupported device: {device}. Please choose from 'cpu', 'cuda', or 'mps'.")
+    return device
+
+
+def get_model_registry() -> pooch.Pooch:
+    """Get the model registry for downloading pre-trained CochleaNet models.
+    """
+    registry = {
+        "SGN": "3058690b49015d6210a8e8414eb341c34189fee660b8fac438f1fdc41bdfff98",
+        "IHC": "89afbcca08ed302aa6dfbaba5bf2530fc13339c05a604b6f2551d97cf5f12774",
+        "Synapses": "2a42712b056f082b4794f15cf41b15678aab0bec1acc922ff9f0dc76abe6747e",
+        # TODO
+        # "SGN-lowres": "",
+        # "IHC-lowres": "",
+    }
+    urls = {
+        "SGN": "https://owncloud.gwdg.de/index.php/s/NZ2vv7hxX1imITG/download",
+        "IHC": "https://owncloud.gwdg.de/index.php/s/GBBJkPQFraz1ZzU/download",
+        "Synapses": "https://owncloud.gwdg.de/index.php/s/A9W5NmOeBxiyZgY/download",
+        # TODO
+        # "SGN-lowres": "",
+        # "IHC-lowres": "",
+    }
+    cache_dir = get_cache_dir()
+    models = pooch.create(
+        path=os.path.join(cache_dir, "models"),
+        base_url="",
+        registry=registry,
+        urls=urls,
+    )
+    return models
+
+
+def get_model_path(model_type: str) -> str:
+    """Get the local path to a pretrained model.
+
+    Args:
+        model_type: The model type.
+
+    Returns:
+        The local path to the model.
+    """
+    model_registry = get_model_registry()
+    model_path = model_registry.fetch(model_type)
+    return model_path
+
+
+def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
+    """Get the model for a specific segmentation type.
+
+    Args:
+        model_type: The model for one of the following segmentation or detection tasks:
+            'SGN', 'IHC', 'Synapses', 'SGN-lowres', 'IHC-lowres'.
+        device: The device to use.
+
+    Returns:
+        The model.
+    """
+    # Resolve the device; this also validates a device that was passed explicitly.
+    device = get_device(device)
+    model_path = get_model_path(model_type)
+    model = torch.load(model_path, weights_only=False)
+    model.to(device)
+    return model
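Taken together, the new module is a small download-and-load pipeline on top of pooch: get_model_registry wires the SHA256 hashes to the ownCloud URLs, get_model_path fetches (and caches) a checkpoint, and get_model loads it onto the resolved device. A minimal usage sketch, assuming the URLs above are reachable; the first call downloads into <cache>/models, later calls reuse the cached file:

    # Minimal usage sketch: fetch and load the pre-trained SGN model.
    from flamingo_tools.model_utils import get_model

    model = get_model("SGN")  # resolves the device automatically
    model.eval()  # the checkpoints store full nn.Module objects, not state dicts

Note that torch.load(..., weights_only=False) is needed precisely because the checkpoints are pickled modules rather than state dicts.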

flamingo_tools/models.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

flamingo_tools/plugin/segmentation_widget.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@

 from .base_widget import BaseWidget
 from .util import get_default_tiling, get_device
-from ..models import get_model, get_model_registry
+from ..model_utils import get_model, get_model_registry


 def _load_custom_model(model_path: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
@@ -111,7 +111,7 @@ def load_model_widget(self):
         title_label = QLabel("Select Model:")

         # Exclude the models that are only offered through the CLI and not in the plugin.
-        model_list = set(get_model_registry().urls.keys())
+        model_list = list(get_model_registry().urls.keys())

         models = ["- choose -"] + model_list
         self.model_selector = QComboBox()
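The set-to-list change in the second hunk is a functional fix rather than a style tweak: a Python list cannot be concatenated with a set, so the old line would raise a TypeError two lines later at models = ["- choose -"] + model_list. A minimal reproduction of the difference:

    # Minimal sketch of the bug the second hunk fixes.
    model_list = {"SGN", "IHC", "Synapses"}  # old code: a set
    # ["- choose -"] + model_list  # TypeError: can only concatenate list (not "set") to list
    models = ["- choose -"] + list(model_list)  # new code: works

Using list(...) over the dict's keys also preserves the registry's insertion order in the dropdown, which a set would not.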
