
Commit cf1ca66

Merge pull request #4 from computational-cell-analytics/tracking
Further implement tracking functionality
2 parents bc35fcc + 4324163 · commit cf1ca66

12 files changed: +272 −108 lines

README.md

Lines changed: 3 additions & 3 deletions
@@ -56,8 +56,8 @@ pip install -e .
 ## Usage
 
 After the installation the three applications for interactive annotations can be started from the command line or within a python script:
-- **2d annotation**: via the command `micro_sam.annotator_2d` or with the function `micro_sam.sam_annotator.annotator_2d` from python. Run `micro_sam.annotator_2d -h` or check out [examples/sam_annotator_2d](https://github.com/computational-cell-analytics/micro-sam/blob/master/examples/sam_annotator_2d.py) for details.
-- **3d annotation**: via the command `micro_sam.annotator_3d` or with the function `micro_sam.sam_annotator.annotator_3d` from python. Run `micro_sam.annotator_3d -h` or check out [examples/sam_annotator_3d](https://github.com/computational-cell-analytics/micro-sam/blob/master/examples/sam_annotator_3d.py) for details.
+- **2d segmentation**: via the command `micro_sam.annotator_2d` or with the function `micro_sam.sam_annotator.annotator_2d` from python. Run `micro_sam.annotator_2d -h` or check out [examples/sam_annotator_2d](https://github.com/computational-cell-analytics/micro-sam/blob/master/examples/sam_annotator_2d.py) for details.
+- **3d segmentation**: via the command `micro_sam.annotator_3d` or with the function `micro_sam.sam_annotator.annotator_3d` from python. Run `micro_sam.annotator_3d -h` or check out [examples/sam_annotator_3d](https://github.com/computational-cell-analytics/micro-sam/blob/master/examples/sam_annotator_3d.py) for details.
 - **tracking**: via the command `micro_sam.annotator_tracking` or with the function `micro_sam.sam_annotator.annotator_tracking` from python. Run `micro_sam.annotator_tracking -h` or check out [examples/sam_annotator_tracking](https://github.com/computational-cell-analytics/micro-sam/blob/master/examples/sam_annotator_tracking.py) for details.
 
 TODO
@@ -94,4 +94,4 @@ micro_sam <- library with utility functionality for using SAM for microscopy dat
 
 If you are using this repository in your research please cite
 - [SegmentAnything](https://arxiv.org/abs/2304.02643)
-- and our repository on [zenodo](TODO) (we are working on a full publication)
+- and our repository on [zenodo](TODO) (we are working on a publication)
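For orientation, the python entry points listed in the README can be used roughly like this. A minimal sketch for the 2d annotator that mirrors the development script touched in this commit (the Lucchi++ data path and the embedding cache path are placeholders for your own data):

```python
from elf.io import open_file
from micro_sam.sam_annotator import annotator_2d

# load a single 2d image as a numpy array; any 2d grayscale image works,
# here a slice of the Lucchi++ EM data is used as in the development script
with open_file("./data/Lucchi++/Test_In") as f:
    raw = f["*.png"][-1, :768, :768]

# the image embeddings are computed once and cached at embedding_path,
# so restarting the annotator on the same data is fast
annotator_2d(raw, embedding_path="./embeddings/embeddings-mito2d.zarr")
```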

development/.gitignore

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+embeddings/
+*.npy

Lines changed: 6 additions & 4 deletions

@@ -7,12 +7,12 @@
 
 
 def mito_segmentation():
-    input_path = "./data/Lucchi++/Test_In"
+    input_path = "../examples/data/Lucchi++/Test_In"
     with open_file(input_path) as f:
         raw = f["*.png"][-1, :768, :768]
 
     predictor = util.get_sam_model()
-    image_embeddings = util.precompute_image_embeddings(predictor, raw, "./embeddings/embeddings-mito2d.zarr")
+    image_embeddings = util.precompute_image_embeddings(predictor, raw, "../examples/embeddings/embeddings-mito2d.zarr")
     embedding_pca = compute_pca(image_embeddings["features"])
 
     seg, initial_seg = segment_from_embeddings(predictor, image_embeddings=image_embeddings, return_initial_seg=True)
@@ -26,14 +26,16 @@ def mito_segmentation():
 
 
 def cell_segmentation():
-    path = "./DIC-C2DH-HeLa/train/01"
+    path = "../examples/data/DIC-C2DH-HeLa/train/01"
     with open_file(path, mode="r") as f:
         timeseries = f["*.tif"][:50]
 
     frame = 11
 
     predictor = util.get_sam_model()
-    image_embeddings = util.precompute_image_embeddings(predictor, timeseries, "./embeddings/embeddings-ctc.zarr")
+    image_embeddings = util.precompute_image_embeddings(
+        predictor, timeseries, "../examples/embeddings/embeddings-ctc.zarr"
+    )
     embedding_pca = compute_pca(image_embeddings["features"][frame])
 
     seg, initial_seg = segment_from_embeddings(

development/tracking.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+from glob import glob
+
+import numpy as np
+from elf.io import open_file
+from micro_sam.sam_annotator import annotator_tracking
+
+
+def debug_tracking(timeseries, embedding_path):
+    import micro_sam.util as util
+    from micro_sam.sam_annotator.annotator_tracking import _track_from_prompts
+
+    predictor = util.get_sam_model()
+    image_embeddings = util.precompute_image_embeddings(predictor, timeseries, embedding_path)
+
+    # seg = np.zeros(timeseries.shape, dtype="uint32")
+    seg = np.load("./seg.npy")
+    assert seg.shape == timeseries.shape
+    slices = np.array([0])
+    stop_upper = False
+
+    _track_from_prompts(seg, predictor, slices, image_embeddings, stop_upper, threshold=0.5, projection="bounding_box")
+
+
+def load_data():
+    pattern = "/home/pape/Work/data/incu_cyte/carmello/videos/MiaPaCa_flat_B3-3_registered/image-*"
+    paths = glob(pattern)
+    paths.sort()
+
+    timeseries = []
+    for p in paths[:45]:
+        with open_file(p, mode="r") as f:
+            timeseries.append(f["phase-contrast"][:])
+    timeseries = np.stack(timeseries)
+    return timeseries
+
+
+def main():
+    timeseries = load_data()
+    embedding_path = "./embeddings/embeddings-tracking.zarr"
+
+    # _check_tracking(timeseries, embedding_path)
+    annotator_tracking(timeseries, embedding_path=embedding_path)
+
+
+if __name__ == "__main__":
+    main()

examples/image_series_annotator_app.py

Lines changed: 3 additions & 10 deletions
@@ -11,7 +11,7 @@
 
 from magicgui import magicgui
 from micro_sam.segment_from_prompts import segment_from_points
-from micro_sam.sam_annotator.util import create_prompt_menu, prompt_layer_to_points
+from micro_sam.sam_annotator.util import create_prompt_menu, prompt_layer_to_points, toggle_label
 from napari import Viewer
 
 
@@ -64,15 +64,8 @@ def image_series_annotator(image_paths, embedding_save_path, output_folder):
 
     # toggle the points between positive / negative
    @v.bind_key("t")
-    def toggle_label(event=None):
-        # get the currently selected label
-        current_properties = prompts.current_properties
-        current_label = current_properties["label"][0]
-        new_label = "negative" if current_label == "positive" else "positive"
-        current_properties["label"] = np.array([new_label])
-        prompts.current_properties = current_properties
-        prompts.refresh()
-        prompts.refresh_colors()
+    def _toggle_label(event=None):
+        toggle_label(prompts)
 
     # bind the segmentation to a key 's'
     @v.bind_key("s")
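The shared helper that replaces these inline callbacks lives in `micro_sam/sam_annotator/util.py`, which is changed in this commit but not shown here. Based on the removed inline code and the new imports, it presumably looks roughly like the following sketch (not the exact implementation):

```python
import numpy as np

# color cycle for positive / negative prompts, replacing the per-module COLOR_CYCLE
LABEL_COLOR_CYCLE = ["#00FF00", "#FF0000"]


def toggle_label(prompts):
    """Switch the currently selected point prompts between positive and negative."""
    # get the currently selected label
    current_properties = prompts.current_properties
    current_label = current_properties["label"][0]
    new_label = "negative" if current_label == "positive" else "positive"
    current_properties["label"] = np.array([new_label])
    prompts.current_properties = current_properties
    prompts.refresh()
    prompts.refresh_colors()
```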

examples/sam_annotator_2d.py

Lines changed: 3 additions & 2 deletions
@@ -13,10 +13,11 @@ def livecell_annotator():
 
 def main():
     # 2d annotator for livecell data
-    # livecell_annotator()
+    livecell_annotator()
 
+    # TODO
     # 2d annotator for cell tracking challenge hela data
-    hela_2d_annotator()
+    # hela_2d_annotator()
 
 
 if __name__ == "__main__":

examples/sam_annotator_tracking.py

Lines changed: 4 additions & 23 deletions
@@ -1,38 +1,19 @@
-from glob import glob
-
-import numpy as np
 from elf.io import open_file
 from micro_sam.sam_annotator import annotator_tracking
 
 
-def track_incucyte_data():
-    pattern = "/home/pape/Work/data/incu_cyte/carmello/videos/MiaPaCa_flat_B3-3_registered/image-*"
-    paths = glob(pattern)
-    paths.sort()
-
-    timeseries = []
-    for p in paths[:45]:
-        with open_file(p, mode="r") as f:
-            timeseries.append(f["phase-contrast"][:])
-    timeseries = np.stack(timeseries)
-
-    annotator_tracking(timeseries, embedding_path="./embeddings/embeddings-tracking.zarr", show_embeddings=False)
-
-
-# TODO describe how to get the data from CTC
+# This runs the interactive tracking annotator for data from the cell tracking challenge:
+# It uses the training data for the HeLA dataset. You can download the data via
+# http://data.celltrackingchallenge.net/training-datasets/DIC-C2DH-HeLa.zip
 def track_ctc_data():
     path = "./data/DIC-C2DH-HeLa/train/01"
     with open_file(path, mode="r") as f:
         timeseries = f["*.tif"][:50]
-
     annotator_tracking(timeseries, embedding_path="./embeddings/embeddings-ctc.zarr")
 
 
 def main():
-    # private data used for initial tests
-    # track_incucyte_data()
-
-    # data from the cell tracking challenges
+    # run interactive tracking for data from the cell tracking challenge
     track_ctc_data()
 
 
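To run this example you first need the DIC-C2DH-HeLa training data referenced in the comment above. A hedged sketch for downloading and unpacking it with the standard library (the layout of the zip and the `train` subfolder used by the example are assumptions, so you may need to move the extracted sequences accordingly):

```python
import os
import zipfile
from urllib.request import urlretrieve

url = "http://data.celltrackingchallenge.net/training-datasets/DIC-C2DH-HeLa.zip"
zip_path = "./data/DIC-C2DH-HeLa.zip"
os.makedirs("./data", exist_ok=True)

# download the training data once
if not os.path.exists(zip_path):
    urlretrieve(url, zip_path)

# unpack next to the example script; adjust the target so that the sequence
# ends up at ./data/DIC-C2DH-HeLa/train/01 as expected by track_ctc_data
with zipfile.ZipFile(zip_path, "r") as f:
    f.extractall("./data")
```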

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 18 additions & 18 deletions
@@ -8,9 +8,9 @@
 from ..visualization import project_embeddings_for_visualization
 from ..segment_instances import segment_from_embeddings
 from ..segment_from_prompts import segment_from_points
-from .util import commit_segmentation_widget, create_prompt_menu, prompt_layer_to_points
-
-COLOR_CYCLE = ["#00FF00", "#FF0000"]
+from .util import (
+    commit_segmentation_widget, create_prompt_menu, prompt_layer_to_points, toggle_label, LABEL_COLOR_CYCLE
+)
 
 
 @magicgui(call_button="Segment Object [S]")
@@ -35,7 +35,7 @@ def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_r
     global PREDICTOR, IMAGE_EMBEDDINGS
 
     PREDICTOR = util.get_sam_model()
-    IMAGE_EMBEDDINGS = util.precompute_image_embeddings(PREDICTOR, raw, save_path=embedding_path)
+    IMAGE_EMBEDDINGS = util.precompute_image_embeddings(PREDICTOR, raw, save_path=embedding_path, ndim=2)
     util.set_precomputed(PREDICTOR, IMAGE_EMBEDDINGS)
 
     #
@@ -45,16 +45,23 @@ def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_r
     v = Viewer()
 
     v.add_image(raw)
-    v.add_labels(data=np.zeros(raw.shape, dtype="uint32"), name="auto_segmentation")
+    if raw.ndim == 2:
+        shape = raw.shape
+    elif raw.ndim == 3 and raw.shape[-1] == 3:
+        shape = raw.shape[:2]
+    else:
+        raise ValueError(f"Invalid input image of shape {raw.shape}. Expect either 2D grayscale or 3D RGB image.")
+
+    v.add_labels(data=np.zeros(shape, dtype="uint32"), name="auto_segmentation")
     if segmentation_result is None:
-        v.add_labels(data=np.zeros(raw.shape, dtype="uint32"), name="committed_objects")
+        v.add_labels(data=np.zeros(shape, dtype="uint32"), name="committed_objects")
     else:
         v.add_labels(segmentation_result, name="committed_objects")
-    v.add_labels(data=np.zeros(raw.shape, dtype="uint32"), name="current_object")
+    v.add_labels(data=np.zeros(shape, dtype="uint32"), name="current_object")
 
     # show the PCA of the image embeddings
     if show_embeddings:
-        embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS["features"], raw.shape)
+        embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS["features"], shape)
         v.add_image(embedding_vis, name="embeddings", scale=scale)
 
     labels = ["positive", "negative"]
@@ -63,7 +70,7 @@ def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_r
         name="prompts",
         properties={"label": labels},
         edge_color="label",
-        edge_color_cycle=COLOR_CYCLE,
+        edge_color_cycle=LABEL_COLOR_CYCLE,
         symbol="o",
         face_color="transparent",
         edge_width=0.5,
@@ -98,15 +105,8 @@ def _commit(v):
         commit_segmentation_widget(v)
 
     @v.bind_key("t")
-    def toggle_label(event=None):
-        # get the currently selected label
-        current_properties = prompts.current_properties
-        current_label = current_properties["label"][0]
-        new_label = "negative" if current_label == "positive" else "positive"
-        current_properties["label"] = np.array([new_label])
-        prompts.current_properties = current_properties
-        prompts.refresh()
-        prompts.refresh_colors()
+    def _toggle_label(event=None):
+        toggle_label(prompts)
 
     @v.bind_key("Shift-C")
     def clear_prompts(v):
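With the new shape check and the `ndim=2` argument to `precompute_image_embeddings`, the 2d annotator should now also accept RGB images of shape (H, W, 3) in addition to 2d grayscale images. A usage sketch under that assumption (imageio is not part of this commit and only stands in for any reader that returns a numpy array):

```python
import imageio.v3 as imageio
from micro_sam.sam_annotator import annotator_2d

# "cells.png" is a placeholder; the image may be grayscale (H, W) or RGB (H, W, 3)
image = imageio.imread("cells.png")

# embeddings for the single image are computed with ndim=2 and cached at embedding_path
annotator_2d(image, embedding_path="./embeddings/embeddings-cells.zarr")
```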

micro_sam/sam_annotator/annotator_3d.py

Lines changed: 15 additions & 13 deletions
@@ -8,9 +8,11 @@
 from .. import util
 from ..segment_from_prompts import segment_from_mask, segment_from_points
 from ..visualization import project_embeddings_for_visualization
-from .util import commit_segmentation_widget, create_prompt_menu, prompt_layer_to_points, segment_slices_with_prompts
-
-COLOR_CYCLE = ["#00FF00", "#FF0000"]
+from .util import (
+    commit_segmentation_widget, create_prompt_menu,
+    prompt_layer_to_points, segment_slices_with_prompts,
+    toggle_label, LABEL_COLOR_CYCLE
+)
 
 
 #
@@ -74,9 +76,16 @@ def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None
     for z_start, z_stop in zip(segmented_slices[:-1], segmented_slices[1:]):
         slice_diff = z_stop - z_start
         z_mid = int((z_start + z_stop) // 2)
+
         if slice_diff == 1: # the slices are adjacent -> we don't need to do anything
             pass
 
+        elif z_start == z0 and stop_lower: # the lower slice is stop: we just segment from upper
+            segment_range(z_stop, z_start, -1, np.less_equal, verbose=verbose)
+
+        elif z_stop == z1 and stop_upper: # the upper slice is stop: we just segment from lower
+            segment_range(z_start, z_stop, 1, np.greater_equal, verbose=verbose)
+
         elif slice_diff == 2: # there is only one slice in between -> use combined mask
             z = z_start + 1
             seg_prompt = np.logical_or(seg[z_start] == 1, seg[z_stop] == 1)
@@ -187,7 +196,7 @@ def annotator_3d(raw, embedding_path=None, show_embeddings=False, segmentation_r
         name="prompts",
         properties={"label": labels},
         edge_color="label",
-        edge_color_cycle=COLOR_CYCLE,
+        edge_color_cycle=LABEL_COLOR_CYCLE,
         symbol="o",
         face_color="transparent",
         edge_width=0.5,
@@ -227,15 +236,8 @@ def _commit(v):
         commit_segmentation_widget(v)
 
     @v.bind_key("t")
-    def toggle_label(event=None):
-        # get the currently selected label
-        current_properties = prompts.current_properties
-        current_label = current_properties["label"][0]
-        new_label = "negative" if current_label == "positive" else "positive"
-        current_properties["label"] = np.array([new_label])
-        prompts.current_properties = current_properties
-        prompts.refresh()
-        prompts.refresh_colors()
+    def _toggle_label(event=None):
+        toggle_label(prompts)
 
     @v.bind_key("Shift-C")
     def clear_prompts(v):
