Skip to content

Commit 659f225

Browse files
Implement rudimentary functionality for tracking annotator
1 parent ff24b34 commit 659f225

File tree

4 files changed

+95
-34
lines changed

4 files changed

+95
-34
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ We implement napari applications for:
77
- interactive tracking of 2d image data
88

99
**Early beta version**
10+
1011
This is an early beta version. Any feedback is welcome, but please be aware that the functionality is evolving fast and not fully tested.
1112

1213
## Functionality overview

micro_sam/sam_annotator/annotator_3d.py

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .. import util
88
from ..segment_from_prompts import segment_from_mask, segment_from_points
99
from ..visualization import compute_pca
10-
from .util import commit_segmentation_widget, create_prompt_menu, prompt_layer_to_points
10+
from .util import commit_segmentation_widget, create_prompt_menu, prompt_layer_to_points, segment_slices_with_prompts
1111

1212
COLOR_CYCLE = ["#00FF00", "#FF0000"]
1313

@@ -18,33 +18,6 @@
1818
#
1919

2020

21-
def _segment_slices_with_prompts(predictor, prompt_layer, image_embeddings, shape):
22-
seg = np.zeros(shape, dtype="uint32")
23-
24-
slices = np.unique(prompt_layer.data[:, 0]).astype("int")
25-
stop_lower, stop_upper = False, False
26-
27-
for z in slices:
28-
prompts_z = prompt_layer_to_points(prompt_layer, z)
29-
30-
# do we end the segmentation at the outer slices?
31-
if prompts_z is None:
32-
if z == slices[0]:
33-
stop_lower = True
34-
elif z == slices[-1]:
35-
stop_upper = True
36-
else:
37-
raise RuntimeError("Stop slices can only be at the start or end")
38-
seg[z] = 0
39-
continue
40-
41-
points, labels = prompts_z
42-
seg_z = segment_from_points(predictor, points, labels, image_embeddings=image_embeddings, i=z)
43-
seg[z] = seg_z
44-
45-
return seg, slices, stop_lower, stop_upper
46-
47-
4821
# TODO refactor
4922
def _segment_volume(
5023
seg, predictor, image_embeddings, segmented_slices,
@@ -74,9 +47,9 @@ def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None
7447
seg[z] = seg_z
7548
z += increment
7649
if stopping_criterion(z, z_stop):
77-
break
7850
if verbose:
7951
print(f"Segment {z_start} to {z_stop}: stop at slice {z}")
52+
break
8053

8154
z0, z1 = int(segmented_slices.min()), int(segmented_slices.max())
8255

@@ -145,7 +118,7 @@ def segment_slice_wigdet(v: Viewer):
145118
def segment_volume_widget(v: Viewer, iou_threshold: float = 0.8, method: str = "mask"):
146119
# step 1: segment all slices with prompts
147120
shape = v.layers["raw"].data.shape
148-
seg, slices, stop_lower, stop_upper = _segment_slices_with_prompts(
121+
seg, slices, stop_lower, stop_upper = segment_slices_with_prompts(
149122
PREDICTOR, v.layers["prompts"], IMAGE_EMBEDDINGS, shape
150123
)
151124

micro_sam/sam_annotator/annotator_tracking.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,60 @@
55
from napari import Viewer
66

77
from .. import util
8-
from ..segment_from_prompts import segment_from_points
8+
from ..segment_from_prompts import segment_from_mask, segment_from_points
99
from ..visualization import compute_pca
10-
from .util import create_prompt_menu, prompt_layer_to_points
10+
from .util import create_prompt_menu, prompt_layer_to_points, segment_slices_with_prompts
1111

1212
COLOR_CYCLE = ["#00FF00", "#FF0000"]
1313

1414

15+
#
16+
# util functionality
17+
#
18+
19+
20+
# TODO motion model!!!
21+
# TODO handle divison annotations + division classifier
22+
def _track_from_prompts(seg, predictor, slices, image_embeddings, stop_upper, threshold, method):
23+
assert method in ("mask", "bounding_box")
24+
if method == "mask":
25+
use_mask, use_box = True, True
26+
else:
27+
use_mask, use_box = False, True
28+
29+
t0 = int(slices.min())
30+
t = t0 + 1
31+
while True:
32+
if t in slices:
33+
seg_prev = None
34+
seg_t = seg[t]
35+
else:
36+
seg_prev = seg[t - 1]
37+
seg_t = segment_from_mask(predictor, seg_prev, image_embeddings=image_embeddings, i=t,
38+
use_mask=use_mask, use_box=use_box)
39+
40+
if (threshold is not None) and (seg_prev is not None):
41+
iou = util.compute_iou(seg_prev, seg_t)
42+
if iou < threshold:
43+
msg = f"Segmentation stopped at frame {t} due to IOU {iou} < {threshold}."
44+
print(msg)
45+
break
46+
47+
seg[t] = seg_t
48+
t += 1
49+
50+
# TODO here we need to also stop once divisions are implemented
51+
# stop tracking if we have stop upper set (i.e. single negative point was set to indicate stop track)
52+
if t == slices[-1] and stop_upper:
53+
break
54+
55+
# stop if we are at the last slce
56+
if t == seg.shape[0] - 1:
57+
break
58+
59+
return seg
60+
61+
1562
#
1663
# the widgets
1764
#
@@ -31,8 +78,18 @@ def segment_frame_wigdet(v: Viewer):
3178

3279

3380
@magicgui(call_button="Track Object [V]", method={"choices": ["bounding_box", "mask"]})
34-
def track_objet_widget(v: Viewer, iou_threshold: float = 0.8, method: str = "bounding_box"):
35-
pass
81+
def track_objet_widget(v: Viewer, iou_threshold: float = 0.8, method: str = "mask"):
82+
# step 1: segment all slices with prompts
83+
shape = v.layers["raw"].data.shape
84+
seg, slices, _, stop_upper = segment_slices_with_prompts(
85+
PREDICTOR, v.layers["prompts"], IMAGE_EMBEDDINGS, shape
86+
)
87+
88+
# step 2: track the object starting from the lowest annotated slice
89+
seg = _track_from_prompts(seg, PREDICTOR, slices, IMAGE_EMBEDDINGS, stop_upper, iou_threshold, method)
90+
91+
v.layers["current_track"].data = seg
92+
v.layers["current_track"].refresh()
3693

3794

3895
def annotator_tracking(raw, embedding_path=None, show_embeddings=False):

micro_sam/sam_annotator/util.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from magicgui.widgets import ComboBox, Container
55
from napari import Viewer
66

7+
from ..segment_from_prompts import segment_from_points
8+
79

810
@magicgui(call_button="Commit [C]")
911
def commit_segmentation_widget(v: Viewer):
@@ -72,3 +74,31 @@ def prompt_layer_to_points(prompt_layer, i=None):
7274
return None
7375

7476
return this_points, this_labels
77+
78+
79+
def segment_slices_with_prompts(predictor, prompt_layer, image_embeddings, shape):
80+
seg = np.zeros(shape, dtype="uint32")
81+
82+
slices = np.unique(prompt_layer.data[:, 0]).astype("int")
83+
stop_lower, stop_upper = False, False
84+
85+
for i in slices:
86+
prompts_i = prompt_layer_to_points(prompt_layer, i)
87+
88+
# TODO also take into account division properties once we have this implemented in tracking
89+
# do we end the segmentation at the outer slices?
90+
if prompts_i is None:
91+
if i == slices[0]:
92+
stop_lower = True
93+
elif i == slices[-1]:
94+
stop_upper = True
95+
else:
96+
raise RuntimeError("Stop slices can only be at the start or end")
97+
seg[i] = 0
98+
continue
99+
100+
points, labels = prompts_i
101+
seg_i = segment_from_points(predictor, points, labels, image_embeddings=image_embeddings, i=i)
102+
seg[i] = seg_i
103+
104+
return seg, slices, stop_lower, stop_upper

0 commit comments

Comments
 (0)