Commit b46f48d

Merge pull request #1049 from computational-cell-analytics/dev
New release
2 parents 4c81670 + 4aa59b1 · commit b46f48d

44 files changed: 2525 additions, 428 deletions

.github/workflows/test.yaml

Lines changed: 1 addition & 3 deletions

@@ -20,9 +20,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest, windows-latest, macos-latest]
-        # 3.12 currently not supported due to issues with nifty.
-        # python-version: ["3.11", "3.12"]
-        python-version: ["3.11"]
+        python-version: ["3.11", "3.12"]

     steps:
       - name: Checkout

.gitignore

Lines changed: 2 additions & 0 deletions

@@ -195,6 +195,8 @@ iterative_prompting_results/
 *.tif
 *.zip
 *MACOSX
+hela_ctc
+clf-test-data

 # Related to i2k workshop folders.
 data/

environment.yaml

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@ dependencies:
   - nifty >=1.2.3
   - imagecodecs
   - magicgui
-  - napari >=0.5.0,<0.6.0
+  - napari
   - natsort
   - pip
   - pooch

examples/README.md

Lines changed: 8 additions & 4 deletions

@@ -1,10 +1,14 @@
 # Examples

 Examples for using the `micro_sam` annotation tools:
-- `annotator_2d.py`: run the interactive 2d annotation tool.
-- `annotator_3d.py`: run the interactive 3d annotation tool.
-- `annotator_tracking.py`: run the interactive tracking annotation tool.
-- `image_series_annotator.py`: run the annotation tool for a series of images.
+- `annotator_2d.py`: Run the interactive 2d annotation tool.
+- `annotator_3d.py`: Run the interactive 3d annotation tool.
+- `annotator_tracking.py`: Run the interactive tracking annotation tool.
+- `image_series_annotator.py`: Run the annotation tool for a series of images.
+
+And Python scripts for automatic segmentation and tracking:
+- `automatic_segmentation.py`: Run automatic segmentation on 2d images.
+- `automatic_tracking.py`: Run automatic tracking on 2d timeseries images.

 And examples for using the `micro_sam` automatic segmentation feature:
 - `quick_start.py`: run the automatic segmentation feature of `micro_sam` on an example 2d image.
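The new `automatic_segmentation.py` script referenced above is not shown in this diff. For orientation, here is a minimal sketch of what such a script can look like with the API used elsewhere in this commit (`get_predictor_and_segmenter` and `automatic_instance_segmentation`); the file names and the model choice are illustrative assumptions, not the script's actual contents:

import imageio.v3 as imageio

from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter

# Illustrative input; replace with your own 2d image.
image = imageio.imread("my-image.tif")

# Load the model and an instance segmenter. amg=False selects the additional
# instance segmentation decoder of the finetuned micro_sam models.
predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm", amg=False)

# Run automatic instance segmentation on the 2d image.
segmentation = automatic_instance_segmentation(predictor, segmenter, input_path=image, ndim=2)

imageio.imwrite("my-segmentation.tif", segmentation, compression="zlib")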

examples/automatic_tracking.py

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+import os
+
+from elf.io import open_file
+
+from micro_sam.util import get_cache_directory
+from micro_sam.sample_data import fetch_tracking_example_data
+from micro_sam.automatic_segmentation import automatic_tracking, get_predictor_and_segmenter
+
+
+DATA_CACHE = os.path.join(get_cache_directory(), "sample_data")
+EMBEDDING_CACHE = os.path.join(get_cache_directory(), "embeddings")
+os.makedirs(EMBEDDING_CACHE, exist_ok=True)
+
+
+def example_automatic_tracking(use_finetuned_model):
+    """Run automatic tracking for data from the cell tracking challenge.
+    """
+    # Download the example tracking data.
+    example_data = fetch_tracking_example_data(DATA_CACHE)
+
+    # Load the example data (load the sequence of tif files as timeseries)
+    with open_file(example_data, mode="r") as f:
+        timeseries = f["*.tif"]
+
+    if use_finetuned_model:
+        embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc-vit_b_lm.zarr")
+        model_type = "vit_b_lm"
+    else:
+        embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc.zarr")
+        model_type = "vit_h"
+
+    predictor, segmenter = get_predictor_and_segmenter(model_type=model_type, amg=False)
+
+    masks_tracked, _ = automatic_tracking(
+        predictor=predictor,
+        segmenter=segmenter,
+        input_path=timeseries[:],
+        output_path="./hela_ctc",
+        embedding_path=embedding_path,
+    )
+
+    import napari
+    v = napari.Viewer()
+    v.add_image(timeseries)
+    v.add_labels(masks_tracked)
+    napari.run()
+
+
+def main():
+    # Whether to use the fine-tuned SAM model.
+    use_finetuned_model = True
+    example_automatic_tracking(use_finetuned_model)
+
+
+if __name__ == "__main__":
+    main()
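Note that `masks_tracked, _ = automatic_tracking(...)` discards the second return value, presumably the lineage information of the tracked objects. A minimal sketch of keeping both results and persisting the masks as a single label stack, assuming a (t, y, x) timeseries stored in one tif file (the file names here are illustrative):

import imageio.v3 as imageio

from micro_sam.automatic_segmentation import automatic_tracking, get_predictor_and_segmenter

timeseries = imageio.imread("my-timeseries.tif")
predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm", amg=False)

# Keep both return values instead of discarding the second one.
masks_tracked, lineage = automatic_tracking(
    predictor=predictor,
    segmenter=segmenter,
    input_path=timeseries,
    output_path="./tracking-results",
    embedding_path="./embeddings-tracking.zarr",
)

# Save the tracked masks next to the CTC-format output written by micro_sam.
imageio.imwrite("my-tracking-masks.tif", masks_tracked, compression="zlib")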

examples/object_classifier.py

Lines changed: 276 additions & 0 deletions

@@ -0,0 +1,276 @@
+import os
+
+import imageio.v3 as imageio
+import numpy as np
+
+from micro_sam.util import get_cache_directory
+from micro_sam.sample_data import fetch_livecell_example_data, fetch_wholeslide_example_data, fetch_3d_example_data
+
+from elf.io import open_file
+
+
+DATA_CACHE = os.path.join(get_cache_directory(), "sample_data")
+EMBEDDING_CACHE = os.path.join(get_cache_directory(), "embeddings")
+os.makedirs(EMBEDDING_CACHE, exist_ok=True)
+
+
+def livecell_annotator():
+    from micro_sam.sam_annotator.object_classifier import object_classifier
+
+    example_data = fetch_livecell_example_data(DATA_CACHE)
+    image = imageio.imread(example_data)
+
+    embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-livecell-vit_b_lm.zarr")
+    model_type = "vit_b_lm"
+
+    # This is the vit-b-lm segmentation.
+    segmentation = imageio.imread("./clf-test-data/livecell-test-seg.tif")
+
+    object_classifier(image, segmentation, embedding_path=embedding_path, model_type=model_type)
+
+
+def wholeslide_annotator():
+    from micro_sam.sam_annotator.object_classifier import object_classifier
+
+    example_data = fetch_wholeslide_example_data(DATA_CACHE)
+    image = imageio.imread(example_data)
+
+    embedding_path = os.path.join(EMBEDDING_CACHE, "whole-slide-embeddings-vit_b_lm.zarr")
+    model_type = "vit_b_lm"
+
+    segmentation = imageio.imread("./clf-test-data/whole-slide-seg.tif")
+    object_classifier(
+        image, segmentation, embedding_path=embedding_path, model_type=model_type,
+        tile_shape=(1024, 1024), halo=(256, 256),
+    )
+
+
+def lucchi_annotator():
+    from micro_sam.sam_annotator.object_classifier import object_classifier
+
+    example_data = fetch_3d_example_data(DATA_CACHE)
+    with open_file(example_data) as f:
+        raw = f["*.png"][:]
+
+    embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-lucchi-vit_b_em_organelles.zarr")
+
+    model_type = "vit_b_em_organelles"
+    segmentation = imageio.imread("./clf-test-data/lucchi-test-segmentation.tif")
+
+    object_classifier(raw, segmentation, embedding_path=embedding_path, model_type=model_type)
+
+
+def tiled_3d_annotator():
+    from micro_sam.sam_annotator.object_classifier import object_classifier
+    from skimage.data import cells3d
+
+    data = cells3d()[30:34, 1]
+    embed_path = "./clf-test-data/embeds-3d-tiled.zarr"
+
+    model_type = "vit_b_lm"
+    segmentation = imageio.imread("./clf-test-data/tiled-3d-segmentation.tif")
+
+    object_classifier(
+        data, segmentation, embedding_path=embed_path, model_type=model_type,
+        tile_shape=(128, 128), halo=(32, 32)
+    )
+
+
+def _get_livecell_data():
+    example_data = fetch_livecell_example_data(DATA_CACHE)
+    image = imageio.imread(example_data)
+
+    embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-livecell-vit_b_lm.zarr")
+
+    # This is the vit-b-lm segmentation and a test annotation.
+    segmentation = imageio.imread("./clf-test-data/livecell-test-seg.tif")
+    annotations = imageio.imread("./clf-test-data/livecell-test-annotations.tif")
+
+    model_type = "vit_b_lm"
+
+    return image, segmentation, annotations, model_type, embedding_path, None, None
+
+
+def _get_wholeslide_data():
+    example_data = fetch_wholeslide_example_data(DATA_CACHE)
+    image = imageio.imread(example_data)
+
+    embedding_path = os.path.join(EMBEDDING_CACHE, "whole-slide-embeddings-vit_b_lm.zarr")
+
+    # This is the vit-b-lm segmentation and a test annotation.
+    segmentation = imageio.imread("./clf-test-data/whole-slide-seg.tif")
+    annotations = imageio.imread("./clf-test-data/wholeslide-annotations.tif")
+
+    model_type = "vit_b_lm"
+
+    return image, segmentation, annotations, model_type, embedding_path, (1024, 1024), (256, 256)
+
+
+def _get_lucchi_data():
+    example_data = fetch_3d_example_data(DATA_CACHE)
+    with open_file(example_data) as f:
+        raw = f["*.png"][:]
+
+    embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-lucchi-vit_b_em_organelles.zarr")
+
+    segmentation = imageio.imread("./clf-test-data/lucchi-test-segmentation.tif")
+    annotations = imageio.imread("./clf-test-data/lucchi-annotations.tif")
+
+    model_type = "vit_b_em_organelles"
+
+    return raw, segmentation, annotations, model_type, embedding_path, None, None
+
+
+def _get_3d_tiled_data():
+    from skimage.data import cells3d
+
+    data = cells3d()[30:34, 1]
+    embed_path = "./clf-test-data/embeds-3d-tiled.zarr"
+    model_type = "vit_b_lm"
+
+    segmentation = imageio.imread("./clf-test-data/tiled-3d-segmentation.tif")
+    annotations = imageio.imread("./clf-test-data/tiled-3d-annotations.tif")
+
+    return data, segmentation, annotations, model_type, embed_path, (128, 128), (32, 32)
+
+
+def annotator_devel():
+    from micro_sam import object_classification as core_clf
+    from micro_sam.sam_annotator import object_classifier as clf
+    from micro_sam.util import precompute_image_embeddings, get_sam_model
+
+    # image, segmentation, annotations, model_type, embedding_path, tile_shape, halo = _get_livecell_data()
+    # image, segmentation, annotations, model_type, embedding_path, tile_shape, halo = _get_wholeslide_data()
+    # image, segmentation, annotations, model_type, embedding_path, tile_shape, halo = _get_lucchi_data()
+    image, segmentation, annotations, model_type, embedding_path, tile_shape, halo = _get_3d_tiled_data()
+
+    # 1. Get the SAM model.
+    predictor = get_sam_model(model_type)
+    # 2. Precompute the image embeddings.
+    image_embeddings = precompute_image_embeddings(
+        predictor, image, save_path=embedding_path, tile_shape=tile_shape, halo=halo
+    )
+    # 3. Get the segmentation ids and the extracted features for the segmentations.
+    seg_ids, features = core_clf.compute_object_features(image_embeddings, segmentation)
+    # 4. Accumulate the point annotations into per-object labels for training the RF.
+    labels = clf._accumulate_labels(segmentation, annotations)
+    # 5. Train the RF model.
+    rf = clf._train_rf(features, labels, n_estimators=200, max_depth=10)
+    # 6. Run prediction with the trained RF.
+    object_prediction = rf.predict(features)
+    # 7. Map the predictions back to the instance segmentation.
+    prediction = core_clf.project_prediction_to_segmentation(segmentation, object_prediction, seg_ids)
+
+    import napari
+    v = napari.Viewer()
+    v.add_image(image)
+    v.add_labels(annotations)
+    v.add_labels(prediction)
+    napari.run()
+
+
+def create_3d_data_with_tiling():
+    from skimage.data import cells3d
+    from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter
+
+    predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm", is_tiled=True)
+    data = cells3d()[30:34, 1]
+
+    embed_path = "./clf-test-data/embeds-3d-tiled.zarr"
+    seg = automatic_instance_segmentation(
+        predictor, segmenter, data, embedding_path=embed_path, ndim=3, tile_shape=(128, 128), halo=(32, 32)
+    )
+
+    import napari
+    v = napari.Viewer()
+    v.add_image(data)
+    v.add_labels(seg)
+    # For annotations.
+    v.add_labels(np.zeros_like(seg))
+    napari.run()
+
+
+def histopathology_annotator():
+    from torch_em.data.datasets.histopathology.lynsec import get_lynsec_paths
+    from micro_sam.sam_annotator import object_classifier as clf
+    from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter
+
+    predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_histopathology")
+
+    image_paths, _ = get_lynsec_paths(path="./clf-test-data/nuclick", choice="ihc", download=True)
+    image_paths = image_paths[:10]
+
+    images, segmentations = [], []
+    embedding_paths = []
+
+    for i, image_path in enumerate(image_paths):
+        image = imageio.imread(image_path)
+        embedding_path = f"./clf-test-data/embeddings_nuclick_{i}.zarr"
+        seg_path = f"./clf-test-data/seg-nuclick_{i}.tif"
+
+        if os.path.exists(seg_path):
+            segmentation = imageio.imread(seg_path)
+        else:
+            segmentation = automatic_instance_segmentation(
+                predictor, segmenter, embedding_path=embedding_path, input_path=image, ndim=2,
+            )
+            imageio.imwrite(seg_path, segmentation, compression="zlib")
+
+        images.append(image)
+        segmentations.append(segmentation)
+        embedding_paths.append(embedding_path)
+
+    clf.image_series_object_classifier(
+        images, segmentations, output_folder="./clf-test-data/histo-results",
+        embedding_paths=embedding_paths, model_type="vit_b_histopathology", ndim=2,
+    )
+
+
+def batch_prediction():
+    import napari
+    from torch_em.data.datasets.histopathology.lynsec import get_lynsec_paths
+    from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter
+    from micro_sam.object_classification import run_prediction_with_object_classifier
+    from tqdm import tqdm
+
+    predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_histopathology")
+
+    image_paths, _ = get_lynsec_paths(path="./clf-test-data/nuclick", choice="ihc", download=True)
+    # Test the batch prediction on the next two images.
+    image_paths = image_paths[10:12]
+
+    images, segmentations = [], []
+    # Prepare images and segmentations.
+    for image_path in tqdm(image_paths, desc="Segment images"):
+        image = imageio.imread(image_path)
+        segmentation = automatic_instance_segmentation(predictor, segmenter, input_path=image, ndim=2, verbose=False)
+        images.append(image)
+        segmentations.append(segmentation)
+
+    rf_path = "clf-test-data/histo-results/rf.joblib"
+    print("Start object clf")
+    predictions = run_prediction_with_object_classifier(images, segmentations, predictor, rf_path, ndim=2)
+
+    for im, seg, pred in zip(images, segmentations, predictions):
+        v = napari.Viewer()
+        v.add_image(im)
+        v.add_labels(seg)
+        v.add_labels(pred)
+        napari.run()
+
+
+def main():
+    # create_3d_data_with_tiling()
+
+    # livecell_annotator()
+    # wholeslide_annotator()
+    # lucchi_annotator()
+    # tiled_3d_annotator()
+    histopathology_annotator()
+    # batch_prediction()
+
+    # annotator_devel()
+
+
+if __name__ == "__main__":
+    main()
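The seven-step workflow in `annotator_devel` trains the random forest through the private helper `clf._train_rf(features, labels, n_estimators=200, max_depth=10)`, and `batch_prediction` later loads it from `rf.joblib`. A minimal sketch of the presumed train-and-persist round trip with plain scikit-learn and joblib (the feature and label arrays below are synthetic stand-ins, and treating `_train_rf` as a thin wrapper around `RandomForestClassifier` is an assumption):

import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Synthetic stand-ins: one feature vector per segmented object and a class
# label for each annotated object.
features = np.random.rand(100, 64)
labels = np.random.randint(1, 3, size=100)

# Train a random forest with the same hyperparameters as in annotator_devel.
rf = RandomForestClassifier(n_estimators=200, max_depth=10)
rf.fit(features, labels)

# Persist it in the location that batch_prediction expects.
joblib.dump(rf, "clf-test-data/histo-results/rf.joblib")

# Later: reload the forest and classify objects from their features.
rf = joblib.load("clf-test-data/histo-results/rf.joblib")
object_prediction = rf.predict(features)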
