Skip to content

Commit 7c2b0c9

Browse files
Add option to use fine-tuned models in example scripts
1 parent 55027d1 commit 7c2b0c9

File tree

2 files changed

+51
-15
lines changed

2 files changed

+51
-15
lines changed

examples/sam_annotator_2d.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,47 +3,72 @@
33
from micro_sam.sample_data import fetch_hela_2d_example_data, fetch_livecell_example_data, fetch_wholeslide_example_data
44

55

6-
def livecell_annotator(use_finetuned_model):
    """Run the 2d annotation tool on a sample image from the LiveCELL dataset.

    See https://doi.org/10.1038/s41592-021-01249-6 for details on the data.

    Args:
        use_finetuned_model: if True, run the experimental fine-tuned
            SAM model ("vit_h_lm") instead of the default "vit_h".
    """
    # Download (if needed) and load the example image.
    example_data = fetch_livecell_example_data("./data")
    image = imageio.imread(example_data)

    # The cached embeddings are model-specific, so each model type
    # gets its own zarr path.
    model_type = "vit_h_lm" if use_finetuned_model else "vit_h"
    embedding_path = (
        "./embeddings/embeddings-livecell-vit_h_lm.zarr"
        if use_finetuned_model
        else "./embeddings/embeddings-livecell.zarr"
    )

    annotator_2d(image, embedding_path, show_embeddings=False, model_type=model_type)
22+
23+
24+
def hela_2d_annotator(use_finetuned_model):
    """Run the 2d annotator for an example image from the cell tracking challenge HeLa 2d dataset.

    Args:
        use_finetuned_model: if True, run the experimental fine-tuned
            SAM model ("vit_h_lm") instead of the default "vit_h".
    """
    example_data = fetch_hela_2d_example_data("./data")
    image = imageio.imread(example_data)

    # Embeddings are cached per model, so the fine-tuned model uses
    # a separate zarr path from the default model.
    if use_finetuned_model:
        embedding_path = "./embeddings/embeddings-hela2d-vit_h_lm.zarr"
        model_type = "vit_h_lm"
    else:
        embedding_path = "./embeddings/embeddings-hela2d.zarr"
        model_type = "vit_h"

    annotator_2d(image, embedding_path, show_embeddings=False, model_type=model_type)
2438

2539

26-
def wholeslide_annotator(use_finetuned_model):
    """Run the 2d annotation tool with tiling on a whole-slide example image
    from the NeurIPS cell segmentation challenge.

    See https://neurips22-cellseg.grand-challenge.org/ for details on the data.

    Args:
        use_finetuned_model: if True, run the experimental fine-tuned
            SAM model ("vit_h_lm") instead of the default "vit_h".
    """
    example_data = fetch_wholeslide_example_data("./data")
    image = imageio.imread(example_data)

    # Pick the model and its matching (model-specific) embedding cache.
    model_type = "vit_h_lm" if use_finetuned_model else "vit_h"
    embedding_path = (
        "./embeddings/whole-slide-embeddings-vit_h_lm.zarr"
        if use_finetuned_model
        else "./embeddings/whole-slide-embeddings.zarr"
    )

    # The image is large, so it is processed in tiles with a halo overlap.
    annotator_2d(image, embedding_path, tile_shape=(1024, 1024), halo=(256, 256), model_type=model_type)
3657

3758

3859
def main():
    """Entry point: pick one of the 2d annotator examples below."""
    # Toggle for the fine-tuned SAM model — still experimental!
    use_finetuned_model = False

    # Uncomment exactly one of the examples:

    # LiveCELL data
    # livecell_annotator(use_finetuned_model)

    # cell tracking challenge HeLa data
    # hela_2d_annotator(use_finetuned_model)

    # whole-slide image
    wholeslide_annotator(use_finetuned_model)
4772

4873

4974
if __name__ == "__main__":

examples/sam_annotator_tracking.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,31 @@
33
from micro_sam.sample_data import fetch_tracking_example_data
44

55

6-
def track_ctc_data(use_finetuned_model):
    """Run interactive tracking for data from the cell tracking challenge.

    Args:
        use_finetuned_model: if True, run the experimental fine-tuned
            SAM model ("vit_h_lm") instead of the default "vit_h".
    """
    # Fetch the example data, then load the sequence of tif files
    # as a timeseries.
    example_data = fetch_tracking_example_data("./data")
    with open_file(example_data, mode="r") as f:
        timeseries = f["*.tif"]

    # Each model type caches its embeddings in a separate zarr path.
    model_type = "vit_h_lm" if use_finetuned_model else "vit_h"
    embedding_path = (
        "./embeddings/embeddings-ctc-vit_h_lm.zarr"
        if use_finetuned_model
        else "./embeddings/embeddings-ctc.zarr"
    )

    # Launch the tracking annotator with cached embeddings.
    annotator_tracking(timeseries, embedding_path=embedding_path, show_embeddings=False, model_type=model_type)
1624

1725

1826
def main():
    """Entry point: run the interactive tracking example."""
    # Toggle for the fine-tuned SAM model — still experimental!
    use_finetuned_model = False

    track_ctc_data(use_finetuned_model)
2031

2132

2233
if __name__ == "__main__":

0 commit comments

Comments
 (0)