44from magicgui import magicgui
55from napari import Viewer
66from napari .utils import progress
7+ from scipy .ndimage import shift
8+ from vigra .filters import eccentricityCenters
79
810from .. import util
911from ..segment_from_prompts import segment_from_mask , segment_from_points
1820#
1921
2022
21- # TODO motion model!!!
23+ def _compute_movement (seg , t0 , t1 ):
24+
25+ def compute_center (t ):
26+ center = np .array (eccentricityCenters (seg [t ].astype ("uint32" )))
27+ assert center .shape == (2 , 2 )
28+ return center [1 ]
29+
30+ center0 = compute_center (t0 )
31+ center1 = compute_center (t1 )
32+
33+ move = center1 - center0
34+ return move .astype ("float64" )
35+
36+
37+ def _update_motion_model (motion_model , move , motion_smoothing ):
38+ alpha = motion_smoothing
39+ motion_model = alpha * motion_model + (1 - alpha ) * move
40+ return motion_model
41+
42+
43+ def _shift_object (mask , motion_model ):
44+ mask_shifted = np .zeros_like (mask )
45+ shift (mask , motion_model , output = mask_shifted , order = 0 , prefilter = False )
46+ return mask_shifted
47+
48+
2249# TODO handle divison annotations + division classifier
23- def _track_from_prompts (seg , predictor , slices , image_embeddings , stop_upper , threshold , method , progress_bar = None ):
50+ def _track_from_prompts (
51+ seg , predictor , slices , image_embeddings , stop_upper , threshold , method ,
52+ progress_bar = None , motion_smoothing = 0.5 ,
53+ ):
2454 assert method in ("mask" , "bounding_box" )
2555 if method == "mask" :
2656 use_mask , use_box = True , True
@@ -31,17 +61,40 @@ def _update_progress():
3161 if progress_bar is not None :
3262 progress_bar .update (1 )
3363
64+ # shift the segmentation based on the motion model and update the motion model
65+ def _update_motion (seg , t , t0 , motion_model ):
66+ seg_prev = seg [t - 1 ]
67+
68+ if t == t0 + 1 : # this is the second frame, we don't have a motion model yet
69+ pass
70+ elif t == t0 + 2 : # this the third frame, we initialize the motion model
71+ current_move = _compute_movement (seg , t - 1 , t - 2 )
72+ motion_model = current_move
73+ else : # we already have a motion model and update it
74+ current_move = _compute_movement (seg , t - 1 , t - 2 )
75+ motion_model = _update_motion_model (motion_model , current_move , motion_smoothing )
76+
77+ if motion_model is not None : # shift the segmentation according to the motion model
78+ seg_prev = _shift_object (seg_prev , motion_model )
79+
80+ return seg_prev , motion_model
81+
82+ motion_model = None
83+ verbose = False
84+
3485 t0 = int (slices .min ())
3586 t = t0 + 1
3687 while True :
3788 if t in slices :
3889 seg_prev = None
3990 seg_t = seg [t ]
4091 else :
41- seg_prev = seg [ t - 1 ]
92+ seg_prev , motion_model = _update_motion ( seg , t , t0 , motion_model )
4293 seg_t = segment_from_mask (predictor , seg_prev , image_embeddings = image_embeddings , i = t ,
4394 use_mask = use_mask , use_box = use_box )
4495 _update_progress ()
96+ if verbose :
97+ print (f"Tracking object in frame { t } with movement { motion_model } " )
4598
4699 if (threshold is not None ) and (seg_prev is not None ):
47100 iou = util .compute_iou (seg_prev , seg_t )
@@ -84,7 +137,7 @@ def segment_frame_wigdet(v: Viewer):
84137
85138
86139@magicgui (call_button = "Track Object [V]" , method = {"choices" : ["bounding_box" , "mask" ]})
87- def track_objet_widget (v : Viewer , iou_threshold : float = 0.8 , method : str = "mask" ):
140+ def track_objet_widget (v : Viewer , iou_threshold : float = 0.5 , method : str = "mask" , motion_smoothing : float = 0.5 ):
88141 shape = v .layers ["raw" ].data .shape
89142
90143 with progress (total = shape [0 ]) as progress_bar :
@@ -95,7 +148,8 @@ def track_objet_widget(v: Viewer, iou_threshold: float = 0.8, method: str = "mas
95148
96149 # step 2: track the object starting from the lowest annotated slice
97150 seg = _track_from_prompts (
98- seg , PREDICTOR , slices , IMAGE_EMBEDDINGS , stop_upper , iou_threshold , method , progress_bar = progress_bar
151+ seg , PREDICTOR , slices , IMAGE_EMBEDDINGS , stop_upper , threshold = iou_threshold ,
152+ method = method , progress_bar = progress_bar , motion_smoothing = motion_smoothing ,
99153 )
100154
101155 v .layers ["current_track" ].data = seg
0 commit comments