Skip to content

Commit 379e661

Browse files
Stop tracks when reaching a division label
1 parent b43d44a commit 379e661

File tree

2 files changed

+51
-10
lines changed

2 files changed

+51
-10
lines changed

micro_sam/sam_annotator/annotator_tracking.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111

1212
from .. import util
1313
from ..segment_from_prompts import segment_from_mask, segment_from_points
14-
from .util import create_prompt_menu, prompt_layer_to_points, segment_slices_with_prompts, LABEL_COLOR_CYCLE
14+
from .util import (
15+
create_prompt_menu, prompt_layer_to_points, prompt_layer_to_state, segment_slices_with_prompts, LABEL_COLOR_CYCLE
16+
)
1517
from ..visualization import project_embeddings_for_visualization
1618

1719
# Magenta and Cyan
@@ -58,7 +60,8 @@ def _shift_object(mask, motion_model):
5860

5961
# TODO handle divison annotations + division classifier
6062
def _track_from_prompts(
61-
seg, predictor, slices, image_embeddings, stop_upper, threshold, projection,
63+
prompt_layer, seg, predictor, slices, image_embeddings,
64+
stop_upper, threshold, projection,
6265
progress_bar=None, motion_smoothing=0.5,
6366
):
6467
assert projection in ("mask", "bounding_box")
@@ -95,16 +98,28 @@ def _update_motion(seg, t, t0, motion_model):
9598
t0 = int(slices.min())
9699
t = t0 + 1
97100
while True:
98-
if t in slices:
101+
102+
if t in slices: # this is a slice with prompts
99103
seg_prev = None
100104
seg_t = seg[t]
101-
else:
105+
track_state = prompt_layer_to_state(prompt_layer, t)
106+
# TODO what do we do with the motion model here?
107+
108+
else: # this is a slice without prompts
102109
seg_prev, motion_model = _update_motion(seg, t, t0, motion_model)
110+
if verbose:
111+
print(f"Tracking object in frame {t} with movement {motion_model}")
103112
seg_t = segment_from_mask(predictor, seg_prev, image_embeddings=image_embeddings, i=t,
104113
use_mask=use_mask, use_box=use_box)
114+
track_state = "track"
115+
116+
# are we beyond the last slice with prompt?
117+
# if no: we continue tracking because we know we need to connect to a future frame
118+
# if yes: we only continue tracking if overlaps are above the threshold
119+
if t < slices[-1]:
120+
seg_prev = None
121+
105122
_update_progress()
106-
if verbose:
107-
print(f"Tracking object in frame {t} with movement {motion_model}")
108123

109124
if (threshold is not None) and (seg_prev is not None):
110125
iou = util.compute_iou(seg_prev, seg_t)
@@ -116,7 +131,6 @@ def _update_motion(seg, t, t0, motion_model):
116131
seg[t] = seg_t
117132
t += 1
118133

119-
# TODO here we need to also stop once divisions are implemented
120134
# stop tracking if we have stop upper set (i.e. single negative point was set to indicate stop track)
121135
if t == slices[-1] and stop_upper:
122136
break
@@ -125,6 +139,10 @@ def _update_motion(seg, t, t0, motion_model):
125139
if t == seg.shape[0] - 1:
126140
break
127141

142+
# stop if we have a division
143+
if track_state == "division":
144+
break
145+
128146
return seg
129147

130148

@@ -167,7 +185,7 @@ def track_objet_widget(
167185

168186
# step 2: track the object starting from the lowest annotated slice
169187
seg = _track_from_prompts(
170-
seg, PREDICTOR, slices, IMAGE_EMBEDDINGS, stop_upper, threshold=iou_threshold,
188+
v.layers["prompts"], seg, PREDICTOR, slices, IMAGE_EMBEDDINGS, stop_upper, threshold=iou_threshold,
171189
projection=projection_, progress_bar=progress_bar, motion_smoothing=motion_smoothing,
172190
)
173191

micro_sam/sam_annotator/util.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def label_changed(new_label):
5454
def prompt_layer_to_points(prompt_layer, i=None):
5555
"""Extract point prompts for SAM from point layer.
5656
57-
Argumtents:
57+
Arguments:
5858
prompt_layer: the point layer
5959
i [int] - index for the data (required for 3d data)
6060
"""
@@ -82,6 +82,30 @@ def prompt_layer_to_points(prompt_layer, i=None):
8282
return this_points, this_labels
8383

8484

85+
def prompt_layer_to_state(prompt_layer, i):
86+
"""Get the state of the track from the prompt layer.
87+
Only relevant for annotator_tracking.
88+
89+
Arguments:
90+
prompt_layer: the point layer
91+
i [int] - index for the data (required for 3d data)
92+
"""
93+
state = prompt_layer.properties["state"]
94+
95+
points = prompt_layer.data
96+
assert points.shape[1] == 3, f"{points.shape}"
97+
mask = points[:, 0] == i
98+
this_points = points[mask][:, 1:]
99+
this_state = state[mask]
100+
assert len(this_points) == len(this_state)
101+
102+
# we set the state to 'division' if at least one point in this frame has a division label
103+
if any(st == "division" for st in this_state):
104+
return "division"
105+
else:
106+
return "track"
107+
108+
85109
def segment_slices_with_prompts(predictor, prompt_layer, image_embeddings, shape, progress_bar=None):
86110
seg = np.zeros(shape, dtype="uint32")
87111

@@ -95,7 +119,6 @@ def _update_progress():
95119
for i in slices:
96120
prompts_i = prompt_layer_to_points(prompt_layer, i)
97121

98-
# TODO also take into account division properties once we have this implemented in tracking
99122
# do we end the segmentation at the outer slices?
100123
if prompts_i is None:
101124

0 commit comments

Comments
 (0)