
Commit 2152b26

Add sample data
1 parent b24b2e8 commit 2152b26

4 files changed: +83 −14 lines

flamingo_tools/model_utils.py

Lines changed: 47 additions & 1 deletion
@@ -1,5 +1,5 @@
 import os
-from typing import Optional, Union
+from typing import Dict, Optional, Union

 import pooch
 import torch
@@ -113,3 +113,49 @@ def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None
     model = torch.load(model_path, weights_only=False)
     model.to(device)
     return model
+
+
+def get_default_tiling() -> Dict[str, Dict[str, int]]:
+    """Determine the tile shape and halo depending on the available VRAM.
+
+    Returns:
+        The default tiling settings for the available computational resources.
+    """
+    if torch.cuda.is_available():
+        # The default halo size.
+        halo = {"x": 64, "y": 64, "z": 16}
+
+        # Determine the GPU RAM and derive a suitable tiling.
+        vram = torch.cuda.get_device_properties(0).total_memory / 1e9
+
+        if vram >= 80:
+            tile = {"x": 640, "y": 640, "z": 80}
+        elif vram >= 40:
+            tile = {"x": 512, "y": 512, "z": 64}
+        elif vram >= 20:
+            tile = {"x": 352, "y": 352, "z": 48}
+        elif vram >= 10:
+            tile = {"x": 256, "y": 256, "z": 32}
+            halo = {"x": 64, "y": 64, "z": 8}  # Choose a smaller halo in z.
+        else:
+            raise NotImplementedError(f"Inference with a GPU with {vram} GB VRAM is not supported.")
+
+        tiling = {"tile": tile, "halo": halo}
+        print(f"Determined tile size for CUDA: {tiling}")
+
+    elif torch.backends.mps.is_available():  # Check for Apple Silicon (MPS).
+        tile = {"x": 256, "y": 256, "z": 16}
+        halo = {"x": 16, "y": 16, "z": 4}
+        tiling = {"tile": tile, "halo": halo}
+        print(f"Determined tile size for MPS: {tiling}")
+
+    # It is unclear what is reasonable on a CPU; for now choose a very small tiling.
+    # (This will not work well on a CPU in any case.)
+    else:
+        tiling = {
+            "tile": {"x": 96, "y": 96, "z": 16},
+            "halo": {"x": 16, "y": 16, "z": 4},
+        }
+        print(f"Determined default tiling for CPU: {tiling}")
+
+    return tiling
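For context, a minimal usage sketch of the new helper (not part of this commit): the halo describes the overlap added around each tile, so the block actually fed to the network per axis is tile plus twice the halo. That arithmetic is an assumption about how the tiling is consumed downstream, not something this commit defines.

# Usage sketch (assumes flamingo_tools is importable; device detection happens inside the helper).
from flamingo_tools.model_utils import get_default_tiling

tiling = get_default_tiling()
tile, halo = tiling["tile"], tiling["halo"]

# Assumed convention: the halo is added on both sides of each tile before prediction,
# so the effective network input block per axis is tile + 2 * halo.
block_shape = {ax: tile[ax] + 2 * halo[ax] for ax in ("z", "y", "x")}
print("tile:", tile)
print("halo:", halo)
print("effective input block:", block_shape)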

flamingo_tools/plugin/segmentation_widget.py

Lines changed: 1 addition & 2 deletions
@@ -9,8 +9,7 @@
 from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox

 from .base_widget import BaseWidget
-from .util import get_default_tiling, get_device
-from ..model_utils import get_model, get_model_registry
+from ..model_utils import get_model, get_model_registry, get_device, get_default_tiling


 def _load_custom_model(model_path: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:

flamingo_tools/plugin/util.py

Lines changed: 0 additions & 8 deletions
@@ -1,8 +0,0 @@
-
-
-def get_default_tiling():
-    pass
-
-
-def get_device():
-    pass

flamingo_tools/test_data.py

Lines changed: 35 additions & 3 deletions
@@ -2,10 +2,12 @@
 from typing import Tuple

 import imageio.v3 as imageio
+import pooch
 import requests
 from skimage.data import binary_blobs, cells3d
 from skimage.measure import label

+from .file_utils import get_cache_dir
 from .segmentation.postprocessing import compute_table_on_the_fly

 SEGMENTATION_URL = "https://owncloud.gwdg.de/index.php/s/kwoGRYiJRRrswgw/download"
@@ -93,13 +95,43 @@ def create_test_data(root: str, size: int = 256, n_channels: int = 2, n_tiles: i
     imageio.imwrite(out_path, data)


+def _sample_registry():
+    urls = {
+        "PV": "https://owncloud.gwdg.de/index.php/s/JVZCOpkILT70sdv/download",
+        "VGlut3": "https://owncloud.gwdg.de/index.php/s/LvGXh0xQR9IKvNk/download",
+        "CTBP2": "https://owncloud.gwdg.de/index.php/s/qaffCaF1sGpqlT3/download",
+    }
+    registry = {
+        "PV": "fbf50cc9119f2dd2bd4dac7d76b746b7d42cab33b94b21f8df304478dd51e632",
+        "VGlut3": "6a3af6ffce3d06588ffdc73df356ac64b83b53aaf6aabeabd49ef6d11d927e20",
+        "CTBP2": "8dcd5f1ebb35194f328788594e275f2452de0e28c85073578dac7100d83c45fc",
+    }
+    cache_dir = get_cache_dir()
+    data_registry = pooch.create(
+        path=os.path.join(cache_dir, "data"),
+        base_url="",
+        registry=registry,
+        urls=urls,
+    )
+    return data_registry
+
+
 def sample_data_pv():
-    pass
+    data_path = _sample_registry().fetch("PV")
+    data = imageio.imread(data_path, extension=".tif")
+    add_image_kwargs = {"name": "PV", "colormap": "gray"}
+    return [(data, add_image_kwargs)]


 def sample_data_vglut3():
-    pass
+    data_path = _sample_registry().fetch("VGlut3")
+    data = imageio.imread(data_path, extension=".tif")
+    add_image_kwargs = {"name": "VGlut3", "colormap": "gray"}
+    return [(data, add_image_kwargs)]


 def sample_data_ctbp2():
-    pass
+    data_path = _sample_registry().fetch("CTBP2")
+    data = imageio.imread(data_path, extension=".tif")
+    add_image_kwargs = {"name": "CTBP2", "colormap": "gray"}
+    return [(data, add_image_kwargs)]
