
Commit fe67103

Implement motion model for tracking
1 parent 5ec88e1 commit fe67103

2 files changed: +79 -6 lines

examples/sam_annotator_tracking.py

Lines changed: 20 additions & 1 deletion
@@ -5,6 +5,23 @@
 from micro_sam.sam_annotator import annotator_tracking
 
 
+# FOR DEBUGGING / DEVELOPMENT
+def _check_tracking(timeseries, embedding_path):
+    import micro_sam.util as util
+    from micro_sam.sam_annotator.annotator_tracking import _track_from_prompts
+
+    predictor = util.get_sam_model()
+    image_embeddings = util.precompute_image_embeddings(predictor, timeseries, embedding_path)
+
+    # seg = np.zeros(timeseries.shape, dtype="uint32")
+    seg = np.load("./seg.npy")
+    assert seg.shape == timeseries.shape
+    slices = np.array([0])
+    stop_upper = False
+
+    _track_from_prompts(seg, predictor, slices, image_embeddings, stop_upper, threshold=0.5, method="bounding_box")
+
+
 def main():
     pattern = "/home/pape/Work/data/incu_cyte/carmello/videos/MiaPaCa_flat_B3-3_registered/image-*"
     paths = glob(pattern)
@@ -16,7 +33,9 @@ def main():
         timeseries.append(f["phase-contrast"][:])
     timeseries = np.stack(timeseries)
 
-    annotator_tracking(timeseries, embedding_path="./embeddings/embeddings-tracking.zarr")
+    embedding_path = "./embeddings/embeddings-tracking.zarr"
+    # _check_tracking(timeseries, embedding_path)
+    annotator_tracking(timeseries, embedding_path=embedding_path)
 
 
 if __name__ == "__main__":
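The debugging helper added above expects a seed segmentation saved at ./seg.npy: a volume with the same shape as the timeseries, with the object annotated in frame 0 (the only entry in slices). A minimal sketch of preparing such a seed, with a made-up shape and mask (both hypothetical, not part of the commit):

import numpy as np

# hypothetical dimensions: (frames, height, width)
timeseries_shape = (50, 512, 512)

# the helper asserts seg.shape == timeseries.shape, so allocate the full volume
seg = np.zeros(timeseries_shape, dtype="uint32")

# hypothetical object mask in the first frame, used as the tracking seed
seg[0, 200:260, 300:350] = 1
np.save("./seg.npy", seg)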

micro_sam/sam_annotator/annotator_tracking.py

Lines changed: 59 additions & 5 deletions
@@ -4,6 +4,8 @@
 from magicgui import magicgui
 from napari import Viewer
 from napari.utils import progress
+from scipy.ndimage import shift
+from vigra.filters import eccentricityCenters
 
 from .. import util
 from ..segment_from_prompts import segment_from_mask, segment_from_points
@@ -18,9 +20,37 @@
 #
 
 
-# TODO motion model!!!
+def _compute_movement(seg, t0, t1):
+
+    def compute_center(t):
+        center = np.array(eccentricityCenters(seg[t].astype("uint32")))
+        assert center.shape == (2, 2)
+        return center[1]
+
+    center0 = compute_center(t0)
+    center1 = compute_center(t1)
+
+    move = center1 - center0
+    return move.astype("float64")
+
+
+def _update_motion_model(motion_model, move, motion_smoothing):
+    alpha = motion_smoothing
+    motion_model = alpha * motion_model + (1 - alpha) * move
+    return motion_model
+
+
+def _shift_object(mask, motion_model):
+    mask_shifted = np.zeros_like(mask)
+    shift(mask, motion_model, output=mask_shifted, order=0, prefilter=False)
+    return mask_shifted
+
+
 # TODO handle division annotations + division classifier
-def _track_from_prompts(seg, predictor, slices, image_embeddings, stop_upper, threshold, method, progress_bar=None):
+def _track_from_prompts(
+    seg, predictor, slices, image_embeddings, stop_upper, threshold, method,
+    progress_bar=None, motion_smoothing=0.5,
+):
     assert method in ("mask", "bounding_box")
     if method == "mask":
         use_mask, use_box = True, True
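The motion model introduced above is an exponential moving average of frame-to-frame displacements: _compute_movement measures how the object's eccentricity center moved between two frames, and _update_motion_model keeps a fraction alpha (= motion_smoothing) of the previous estimate while blending in a fraction (1 - alpha) of the newest displacement. A standalone sketch of the update rule, with made-up displacement values:

import numpy as np

def update_motion_model(motion_model, move, motion_smoothing):
    # same exponential smoothing as _update_motion_model above
    alpha = motion_smoothing
    return alpha * motion_model + (1 - alpha) * move

# hypothetical per-frame displacements of an object center, as (dy, dx)
moves = [np.array([2.0, 0.0]), np.array([3.0, 1.0]), np.array([10.0, -4.0])]

motion_model = moves[0]  # initialized from the first observed movement
for move in moves[1:]:
    motion_model = update_motion_model(motion_model, move, motion_smoothing=0.5)

# after the second move the model is (2.5, 0.5); after the sudden jump to
# (10, -4) it is only (6.25, -1.75): the smoothing damps outlier movements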
@@ -31,17 +61,40 @@ def _update_progress():
         if progress_bar is not None:
             progress_bar.update(1)
 
+    # shift the segmentation based on the motion model and update the motion model
+    def _update_motion(seg, t, t0, motion_model):
+        seg_prev = seg[t - 1]
+
+        if t == t0 + 1:  # this is the second frame, we don't have a motion model yet
+            pass
+        elif t == t0 + 2:  # this is the third frame, we initialize the motion model
+            current_move = _compute_movement(seg, t - 1, t - 2)
+            motion_model = current_move
+        else:  # we already have a motion model and update it
+            current_move = _compute_movement(seg, t - 1, t - 2)
+            motion_model = _update_motion_model(motion_model, current_move, motion_smoothing)
+
+        if motion_model is not None:  # shift the segmentation according to the motion model
+            seg_prev = _shift_object(seg_prev, motion_model)
+
+        return seg_prev, motion_model
+
+    motion_model = None
+    verbose = False
+
     t0 = int(slices.min())
     t = t0 + 1
     while True:
         if t in slices:
             seg_prev = None
             seg_t = seg[t]
         else:
-            seg_prev = seg[t - 1]
+            seg_prev, motion_model = _update_motion(seg, t, t0, motion_model)
             seg_t = segment_from_mask(predictor, seg_prev, image_embeddings=image_embeddings, i=t,
                                       use_mask=use_mask, use_box=use_box)
             _update_progress()
+            if verbose:
+                print(f"Tracking object in frame {t} with movement {motion_model}")
 
         if (threshold is not None) and (seg_prev is not None):
             iou = util.compute_iou(seg_prev, seg_t)
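Before prompting segment_from_mask, the loop now translates the previous frame's mask by the smoothed motion vector via _shift_object; order=0 (nearest-neighbor) keeps the shifted array a valid label mask instead of interpolating values. A toy example of the same scipy.ndimage.shift call:

import numpy as np
from scipy.ndimage import shift

mask = np.zeros((6, 6), dtype="uint32")
mask[1:3, 1:3] = 1  # small square object

# shift 2 rows down and 1 column right; regions shifted in from outside are
# filled with 0, and order=0 avoids introducing non-integer label values
mask_shifted = np.zeros_like(mask)
shift(mask, np.array([2.0, 1.0]), output=mask_shifted, order=0, prefilter=False)

assert mask_shifted[3:5, 2:4].all()  # the object now sits 2 down, 1 right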
@@ -84,7 +137,7 @@ def segment_frame_wigdet(v: Viewer):
 
 
 @magicgui(call_button="Track Object [V]", method={"choices": ["bounding_box", "mask"]})
-def track_objet_widget(v: Viewer, iou_threshold: float = 0.8, method: str = "mask"):
+def track_objet_widget(v: Viewer, iou_threshold: float = 0.5, method: str = "mask", motion_smoothing: float = 0.5):
     shape = v.layers["raw"].data.shape
 
     with progress(total=shape[0]) as progress_bar:
@@ -95,7 +148,8 @@ def track_objet_widget(v: Viewer, iou_threshold: float = 0.8, method: str = "mas
 
         # step 2: track the object starting from the lowest annotated slice
         seg = _track_from_prompts(
-            seg, PREDICTOR, slices, IMAGE_EMBEDDINGS, stop_upper, iou_threshold, method, progress_bar=progress_bar
+            seg, PREDICTOR, slices, IMAGE_EMBEDDINGS, stop_upper, threshold=iou_threshold,
+            method=method, progress_bar=progress_bar, motion_smoothing=motion_smoothing,
         )
 
         v.layers["current_track"].data = seg
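The widget now forwards motion_smoothing to _track_from_prompts and lowers the default iou_threshold from 0.8 to 0.5. That threshold is compared against util.compute_iou(seg_prev, seg_t), presumably the standard intersection over union of the two masks, sketched here as an assumption:

import numpy as np

def compute_iou(mask_a, mask_b):
    # assumed behavior of util.compute_iou: IoU of two binary masks
    mask_a, mask_b = mask_a > 0, mask_b > 0
    overlap = np.logical_and(mask_a, mask_b).sum()
    union = np.logical_or(mask_a, mask_b).sum()
    return overlap / union if union > 0 else 0.0

Since the previous mask is motion-compensated before the comparison, the IoU now measures prediction quality relative to where the object is expected to be, not where it was in the last frame.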
