
Commit 17fafa2

Merge pull request #98 from computational-cell-analytics/custom_models
Add preliminary custom models
2 parents 449e025 + ecc3434 commit 17fafa2

File tree

8 files changed: +92, -26 lines

  environment_cpu.yaml
  environment_gpu.yaml
  examples/sam_annotator_2d.py
  examples/sam_annotator_tracking.py
  micro_sam/sam_annotator/annotator_3d.py
  micro_sam/sam_annotator/annotator_tracking.py
  micro_sam/sam_annotator/image_series_annotator.py
  micro_sam/util.py

environment_cpu.yaml

Lines changed: 3 additions & 2 deletions
@@ -9,7 +9,8 @@ dependencies:
   - pooch
   - python-elf >=0.4.8
   - pytorch
+  - segment-anything
   - torchvision
   - tqdm
-  - pip:
-    - git+https://github.com/facebookresearch/segment-anything.git
+  # - pip:
+  #   - git+https://github.com/facebookresearch/segment-anything.git

environment_gpu.yaml

Lines changed: 3 additions & 2 deletions
@@ -10,7 +10,8 @@ dependencies:
   - python-elf >=0.4.8
   - pytorch
   - pytorch-cuda>=11.7  # you may need to update the cuda version to match your system
+  - segment-anything
   - torchvision
   - tqdm
-  - pip:
-    - git+https://github.com/facebookresearch/segment-anything.git
+  # - pip:
+  #   - git+https://github.com/facebookresearch/segment-anything.git
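
Both environment files now install segment-anything as a conda package instead of via pip from GitHub. A quick, non-authoritative way to check that the new dependency resolves after rebuilding the environment; the registry keys mentioned in the comment come from the upstream package, not from this commit:

# sanity check for the conda-packaged segment-anything dependency
from segment_anything import SamPredictor, sam_model_registry

# the upstream registry exposes the architecture keys ("vit_h", "vit_l", "vit_b") used by micro_sam.util
print(sorted(sam_model_registry.keys()))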

examples/sam_annotator_2d.py

Lines changed: 37 additions & 12 deletions
@@ -3,47 +3,72 @@
 from micro_sam.sample_data import fetch_hela_2d_example_data, fetch_livecell_example_data, fetch_wholeslide_example_data


-def livecell_annotator():
+def livecell_annotator(use_finetuned_model):
     """Run the 2d annotator for an example image from the LiveCELL dataset.

     See https://doi.org/10.1038/s41592-021-01249-6 for details on the data.
     """
     example_data = fetch_livecell_example_data("./data")
     image = imageio.imread(example_data)
-    embedding_path = "./embeddings/embeddings-livecell.zarr"
-    annotator_2d(image, embedding_path, show_embeddings=False)

+    if use_finetuned_model:
+        embedding_path = "./embeddings/embeddings-livecell-vit_h_lm.zarr"
+        model_type = "vit_h_lm"
+    else:
+        embedding_path = "./embeddings/embeddings-livecell.zarr"
+        model_type = "vit_h"

-def hela_2d_annotator():
+    annotator_2d(image, embedding_path, show_embeddings=False, model_type=model_type)
+
+
+def hela_2d_annotator(use_finetuned_model):
     """Run the 2d annotator for an example image form the cell tracking challenge HeLa 2d dataset.
     """
     example_data = fetch_hela_2d_example_data("./data")
     image = imageio.imread(example_data)
-    embedding_path = "./embeddings/embeddings-hela2d.zarr"
-    annotator_2d(image, embedding_path, show_embeddings=False)
+
+    if use_finetuned_model:
+        embedding_path = "./embeddings/embeddings-hela2d-vit_h_lm.zarr"
+        model_type = "vit_h_lm"
+    else:
+        embedding_path = "./embeddings/embeddings-hela2d.zarr"
+        model_type = "vit_h"
+
+    annotator_2d(image, embedding_path, show_embeddings=False, model_type=model_type)


-def wholeslide_annotator():
+def wholeslide_annotator(use_finetuned_model):
     """Run the 2d annotator with tiling for an example whole-slide image from the
     NeuRIPS cell segmentation challenge.

     See https://neurips22-cellseg.grand-challenge.org/ for details on the data.
     """
     example_data = fetch_wholeslide_example_data("./data")
     image = imageio.imread(example_data)
-    embedding_path = "./embeddings/whole-slide-embeddings.zarr"
-    annotator_2d(image, embedding_path, tile_shape=(1024, 1024), halo=(256, 256))
+
+    if use_finetuned_model:
+        embedding_path = "./embeddings/whole-slide-embeddings-vit_h_lm.zarr"
+        model_type = "vit_h_lm"
+    else:
+        embedding_path = "./embeddings/whole-slide-embeddings.zarr"
+        model_type = "vit_h"
+
+    annotator_2d(image, embedding_path, tile_shape=(1024, 1024), halo=(256, 256), model_type=model_type)


 def main():
+    # whether to use the fine-tuned SAM model
+    # this feature is still experimental!
+    use_finetuned_model = False
+
     # 2d annotator for livecell data
-    # livecell_annotator()
+    # livecell_annotator(use_finetuned_model)

     # 2d annotator for cell tracking challenge hela data
-    hela_2d_annotator()
+    # hela_2d_annotator(use_finetuned_model)

     # 2d annotator for a whole slide image
-    # wholeslide_annotator()
+    wholeslide_annotator(use_finetuned_model)


 if __name__ == "__main__":
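
The example functions now take a use_finetuned_model flag and forward the matching model_type to annotator_2d. A minimal sketch of the same call for your own data, assuming annotator_2d is imported from micro_sam.sam_annotator as in this script; the image and embedding paths below are placeholders:

import imageio
from micro_sam.sam_annotator import annotator_2d

# placeholder input image; any 2d image readable by imageio works here
image = imageio.imread("my_cells.tif")
# select the preliminary fine-tuned model via its model type
annotator_2d(image, embedding_path="./embeddings/my-cells-vit_h_lm.zarr", model_type="vit_h_lm")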

examples/sam_annotator_tracking.py

Lines changed: 14 additions & 3 deletions
@@ -3,20 +3,31 @@
 from micro_sam.sample_data import fetch_tracking_example_data


-def track_ctc_data():
+def track_ctc_data(use_finetuned_model):
     """Run interactive tracking for data from the cell tracking challenge.
     """
     # download the example data
     example_data = fetch_tracking_example_data("./data")
     # load the example data (load the sequence of tif files as timeseries)
     with open_file(example_data, mode="r") as f:
         timeseries = f["*.tif"]
+
+    if use_finetuned_model:
+        embedding_path = "./embeddings/embeddings-ctc-vit_h_lm.zarr"
+        model_type = "vit_h_lm"
+    else:
+        embedding_path = "./embeddings/embeddings-ctc.zarr"
+        model_type = "vit_h"
+
     # start the annotator with cached embeddings
-    annotator_tracking(timeseries, embedding_path="./embeddings/embeddings-ctc.zarr", show_embeddings=False)
+    annotator_tracking(timeseries, embedding_path=embedding_path, show_embeddings=False, model_type=model_type)


 def main():
-    track_ctc_data()
+    # whether to use the fine-tuned SAM model
+    # this feature is still experimental!
+    use_finetuned_model = False
+    track_ctc_data(use_finetuned_model)


 if __name__ == "__main__":
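
The tracking example gets the same experimental flag. A minimal sketch for launching the tracking annotator on your own timeseries with the preliminary model, assuming annotator_tracking is imported from micro_sam.sam_annotator as in this script; the array below is a random placeholder standing in for real data:

import numpy as np
from micro_sam.sam_annotator import annotator_tracking

# placeholder (t, y, x) timeseries; substitute your own image stack here
timeseries = np.random.randint(0, 255, (5, 512, 512), dtype="uint8")
annotator_tracking(timeseries, embedding_path="./embeddings/demo-ctc-vit_h_lm.zarr", model_type="vit_h_lm")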

micro_sam/sam_annotator/annotator_3d.py

Lines changed: 7 additions & 1 deletion
@@ -7,6 +7,7 @@
 from magicgui import magicgui
 from napari import Viewer
 from napari.utils import progress
+from segment_anything import SamPredictor

 from .. import util
 from ..prompt_based_segmentation import segment_from_mask
@@ -195,10 +196,15 @@ def annotator_3d(
     tile_shape: Optional[Tuple[int, int]] = None,
     halo: Optional[Tuple[int, int]] = None,
     return_viewer: bool = False,
+    predictor: Optional[SamPredictor] = None,
 ) -> None:
     # for access to the predictor and the image embeddings in the widgets
     global PREDICTOR, IMAGE_EMBEDDINGS
-    PREDICTOR = util.get_sam_model(model_type=model_type)
+
+    if predictor is None:
+        PREDICTOR = util.get_sam_model(model_type=model_type)
+    else:
+        PREDICTOR = predictor
     IMAGE_EMBEDDINGS = util.precompute_image_embeddings(
         PREDICTOR, raw, save_path=embedding_path, tile_shape=tile_shape, halo=halo,
         wrong_file_callback=show_wrong_file_warning,
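
The new predictor argument lets callers build a SamPredictor once, e.g. via micro_sam.util.get_sam_model, and hand it to the annotator instead of having it re-created from model_type. A sketch of that pattern, assuming annotator_3d is exported from micro_sam.sam_annotator; the volume is a random placeholder:

import numpy as np
from micro_sam import util
from micro_sam.sam_annotator import annotator_3d

# create the predictor once; the preliminary fine-tuned model is selected by its model type
predictor = util.get_sam_model(model_type="vit_h_lm")

# placeholder (z, y, x) volume; the shared predictor is passed instead of a model_type
volume = np.random.randint(0, 255, (16, 512, 512), dtype="uint8")
annotator_3d(volume, embedding_path="./embeddings/demo-3d.zarr", predictor=predictor)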

micro_sam/sam_annotator/annotator_tracking.py

Lines changed: 6 additions & 1 deletion
@@ -9,6 +9,7 @@
 from napari import Viewer
 from napari.utils import progress
 from scipy.ndimage import shift
+from segment_anything import SamPredictor

 # this is more precise for comuting the centers, but slow!
 # from vigra.filters import eccentricityCenters
@@ -355,12 +356,16 @@ def annotator_tracking(
     tile_shape: Optional[Tuple[int, int]] = None,
     halo: Optional[Tuple[int, int]] = None,
     return_viewer: bool = False,
+    predictor: Optional[SamPredictor] = None,
 ) -> None:
     # global state
     global PREDICTOR, IMAGE_EMBEDDINGS, CURRENT_TRACK_ID, LINEAGE
     global TRACKING_WIDGET

-    PREDICTOR = util.get_sam_model(model_type=model_type)
+    if predictor is None:
+        PREDICTOR = util.get_sam_model(model_type=model_type)
+    else:
+        PREDICTOR = predictor
     IMAGE_EMBEDDINGS = util.precompute_image_embeddings(
         PREDICTOR, raw, save_path=embedding_path, tile_shape=tile_shape, halo=halo,
         wrong_file_callback=show_wrong_file_warning,

micro_sam/sam_annotator/image_series_annotator.py

Lines changed: 7 additions & 2 deletions
@@ -8,6 +8,8 @@

 from magicgui import magicgui
 from napari.utils import progress as tqdm
+from segment_anything import SamPredictor
+
 from .annotator_2d import annotator_2d
 from .. import util

@@ -32,6 +34,7 @@ def image_series_annotator(
     image_files: List[str],
     output_folder: str,
     embedding_path: Optional[str] = None,
+    predictor: Optional[SamPredictor] = None,
     **kwargs
 ) -> None:
     """
@@ -45,7 +48,8 @@
     os.makedirs(output_folder, exist_ok=True)
     next_image_id = 0

-    predictor = util.get_sam_model(model_type=kwargs.get("model_type", "vit_h"))
+    if predictor is None:
+        predictor = util.get_sam_model(model_type=kwargs.get("model_type", "vit_h"))
     if embedding_path is None:
         embedding_paths = None
     else:
@@ -101,12 +105,13 @@ def image_folder_annotator(
     output_folder: str,
     pattern: str = "*",
     embedding_path: Optional[str] = None,
+    predictor: Optional[SamPredictor] = None,
     **kwargs
 ) -> None:
     """
     """
     image_files = sorted(glob(os.path.join(input_folder, pattern)))
-    image_series_annotator(image_files, output_folder, embedding_path, **kwargs)
+    image_series_annotator(image_files, output_folder, embedding_path, predictor, **kwargs)


 def main():
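
image_series_annotator and image_folder_annotator accept the same optional predictor, so the model is loaded once for the whole series rather than per image. A sketch with placeholder folder names, assuming image_folder_annotator is exported from micro_sam.sam_annotator alongside the other annotators:

from micro_sam import util
from micro_sam.sam_annotator import image_folder_annotator

# one predictor shared across all images of the series; input/output folders are placeholders
predictor = util.get_sam_model(model_type="vit_b_lm")
image_folder_annotator(
    "./my_images", "./my_annotations", pattern="*.tif",
    embedding_path="./embeddings", predictor=predictor,
)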

micro_sam/util.py

Lines changed: 15 additions & 3 deletions
@@ -25,13 +25,19 @@
 _MODEL_URLS = {
     "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
     "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
-    "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
+    "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
+    # preliminary finetuned models
+    "vit_h_lm": "https://owncloud.gwdg.de/index.php/s/CnxBvsdGPN0TD3A/download",
+    "vit_b_lm": "https://owncloud.gwdg.de/index.php/s/gGlR1LFsav0eQ2k/download",
 }
 _CHECKPOINT_FOLDER = os.environ.get("SAM_MODELS", os.path.expanduser("~/.sam_models"))
 _CHECKSUMS = {
     "vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e",
     "vit_l": "3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622",
-    "vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912"
+    "vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912",
+    # preliminary finetuned models
+    "vit_h_lm": "c30a580e6ccaff2f4f0fbaf9cad10cee615a915cdd8c7bc4cb50ea9bdba3fc09",
+    "vit_b_lm": "f2b8676f92a123f6f8ac998818118bd7269a559381ec60af4ac4be5c86024a1b",
 }


@@ -105,7 +111,13 @@ def get_sam_model(
     """
     checkpoint = _get_checkpoint(model_type, checkpoint_path)
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    sam = sam_model_registry[model_type](checkpoint=checkpoint)
+
+    # Our custom model types have a suffix "_...". This suffix needs to be stripped
+    # before calling sam_model_registry.
+    model_type_ = model_type[:5]
+    assert model_type_ in ("vit_h", "vit_b", "vit_l")
+
+    sam = sam_model_registry[model_type_](checkpoint=checkpoint)
     sam.to(device=device)
     predictor = SamPredictor(sam)
     if return_sam:
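
The new model types keep the architecture prefix in their name, so stripping everything after the first five characters recovers the key expected by segment-anything's sam_model_registry. A small sketch of that mapping and of loading one of the preliminary models; the checkpoint is downloaded to ~/.sam_models (or $SAM_MODELS) on first use, and the variable names are illustrative:

from micro_sam import util

# "vit_h_lm"[:5] == "vit_h": the fine-tuned checkpoint is loaded into the vit_h architecture
for model_type in ("vit_h", "vit_h_lm", "vit_b_lm"):
    print(model_type, "->", model_type[:5])

# downloads the preliminary fine-tuned checkpoint and wraps it in a SamPredictor
predictor = util.get_sam_model(model_type="vit_h_lm")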
