1- import numpy as np
21import datetime
2+ from collections import namedtuple
3+ from typing import NamedTuple , Optional , Tuple
4+
5+ import numpy as np
36
47from stytra .collectors import QueueDataAccumulator
8+ from stytra .collectors .namedtuplequeue import NamedTupleQueue
59from stytra .utilities import reduce_to_pi
6- from collections import namedtuple
710
811
912class Estimator :
@@ -15,11 +18,19 @@ class Estimator:
1518
1619 def __init__ (self , acc_tracking : QueueDataAccumulator , experiment ):
1720 self .exp = experiment
18- self .log = experiment .estimator_log
1921 self .acc_tracking = acc_tracking
22+ self .output_queue = NamedTupleQueue ()
23+ self ._output_type = None
24+
25+ def update (self ):
26+ raise NotImplementedError
2027
2128 def reset (self ):
22- self .log .reset ()
29+ pass
30+
31+
32+ class VigorEstimate (NamedTuple ):
33+ vigor : float
2334
2435
2536class VigorMotionEstimator (Estimator ):
@@ -34,27 +45,15 @@ def __init__(self, *args, vigor_window=0.050, base_gain=-12, **kwargs):
3445 self .vigor_window = vigor_window
3546 self .last_dt = 1 / 500.0
3647 self .base_gain = base_gain
37- self ._output_type = namedtuple ("s" , "vigor" )
38-
39- def get_velocity (self , lag = 0 ):
40- """
41-
42- Parameters
43- ----------
44- lag :
45- (Default value = 0)
48+ self ._output_type = namedtuple ("vigor_estimate" , ("vigor" ,))
4649
47- Returns
48- -------
49-
50- """
50+ def get_vigor (self ):
5151 vigor_n_samples = max (int (round (self .vigor_window / self .last_dt )), 2 )
52- n_samples_lag = max (int (round (lag / self .last_dt )), 0 )
5352 if not self .acc_tracking .stored_data :
5453 return 0
55- past_tail_motion = self .acc_tracking .get_last_n (
56- vigor_n_samples + n_samples_lag
57- )[ 0 : vigor_n_samples ]
54+ past_tail_motion = self .acc_tracking .get_last_n (vigor_n_samples )[
55+ 0 : vigor_n_samples
56+ ]
5857 end_t = past_tail_motion .t .iloc [- 1 ]
5958 start_t = past_tail_motion .t .iloc [0 ]
6059 new_dt = (end_t - start_t ) / vigor_n_samples
@@ -63,10 +62,15 @@ def get_velocity(self, lag=0):
6362 vigor = np .nanstd (np .array (past_tail_motion .tail_sum ))
6463 if np .isnan (vigor ):
6564 vigor = 0
65+ return end_t , vigor
66+
67+ def update (self ):
68+ end_t , vigor = self .get_vigor ()
69+ self .output_queue .put (end_t , self ._output_type (vigor ))
6670
67- if len ( self . log . times ) == 0 or self . log . times [ - 1 ] < end_t :
68- self . log . update_list ( end_t , self . _output_type ( vigor ))
69- return vigor * self . base_gain
71+
72+ class BoutEstimate ( NamedTuple ):
73+ is_bouting : bool
7074
7175
7276class BoutsEstimator (VigorMotionEstimator ):
@@ -78,138 +82,59 @@ def __init__(
7882 self .vigor_window = vigor_window
7983 self .min_interbout = min_interbout
8084 self .last_bout_t = None
85+ self ._output_type = namedtuple ("bouts" , ("is_bouting" ,))
8186
82- def bout_occured (self ):
83- if self .get_velocity () > self .base_gain * self .bout_threshold :
87+ def update (self ):
88+ end_t , vigor = self .get_vigor ()
89+ is_bouting = False
90+ if vigor > self .base_gain * self .bout_threshold :
8491 if (
8592 self .last_bout_t is None
8693 or (datetime .datetime .now () - self .last_bout_t ).total_seconds ()
8794 > self .min_interbout
8895 ):
8996 self .last_bout_t = datetime .datetime .now ()
90- return True
91- return False
97+ is_bouting = True
98+ self . output_queue . put ( end_t , self . _output_type ( is_bouting ))
9299
93100
94- class TailSumEstimator (Estimator ):
95- def __init__ (
96- self ,
97- * args ,
98- vigor_window = 0.050 ,
99- theta_window = 0.07 ,
100- base_gain = - 30 ,
101- bout_threshold = 0.05 ,
102- min_interbout = 0.1 ,
103- ** kwargs
104- ):
105- super ().__init__ (* args , ** kwargs )
106- self .vigor_window = vigor_window
107- self .theta_window = theta_window
108- self .last_dt = 1 / 500.0
109- self .base_gain = base_gain
110- self ._output_type = namedtuple ("s" , ("vigor" , "theta" , "bout_on" ))
111- self .bout_threshold = bout_threshold
112- self .vigor_window = vigor_window
113- self .min_interbout = min_interbout
114- self .last_bout_t = None
115- self .prev_time_on = False
116- self .bout_onset = 0
117- self .bout_on = 0
118- self .theta_provided = True
119- self .last_bout_t = 0
120- self .last_vigor = 0
121- self .last_bout_on = 0
122-
123- self .tail_th = 0
124-
125- def bout_occured (self ):
126- if self .bout_on :
127- if (
128- self .last_bout_t is None
129- or (datetime .datetime .now () - self .last_bout_t ).total_seconds ()
130- > self .min_interbout
131- ):
132- self .last_bout_t = datetime .datetime .now ()
133- return True
134- return False
101+ class EmbeddedBoutEstimate (NamedTuple ):
102+ vigor : float
103+ theta : float
104+ bout_on : bool
135105
136- def get_vel_and_theta (self , lag = 0 ):
137- """
138106
139- Parameters
140- ----------
141- lag :
142- (Default value = 0)
107+ class PositionEstimate ( NamedTuple ):
108+ x : float
109+ y : float
110+ theta : float
143111
144- Returns
145- -------
146112
147- """
148- # Vigor (copypasted from VigorEstimator method for simplicity)
149- vigor_n_samples = max (int (round (self .vigor_window / self .last_dt )), 2 )
150- n_samples_lag = max (int (round (lag / self .last_dt )), 0 )
151- if not self .acc_tracking .stored_data :
152- return 0 , 0 , 0
153- past_tail_motion = self .acc_tracking .get_last_n (
154- vigor_n_samples + n_samples_lag
155- )[0 :vigor_n_samples ]
156- end_t = past_tail_motion .t .iloc [- 1 ]
157- start_t = past_tail_motion .t .iloc [0 ]
158- new_dt = (end_t - start_t ) / vigor_n_samples
159- if new_dt > 0 :
160- self .last_dt = new_dt
161- vigor = np .nanstd (np .array (past_tail_motion .tail_sum ))
113+ def _propagate_change_above_threshold (
114+ current_estimate : PositionEstimate ,
115+ previous_estimate : Optional [PositionEstimate ],
116+ thresholds : PositionEstimate ,
117+ ) -> PositionEstimate :
118+ """Return updated components of a position if the component changed enough, otherwise return the old component"""
119+ if previous_estimate is None :
120+ return current_estimate
162121
163- if vigor is not None :
164- self .bout_on = int (vigor > self .bout_threshold )
165- else :
166- self .bout_on = int (self .last_vigor > self .bout_threshold )
167-
168- if self .bout_onset == 0 :
169- if self .bout_on and not self .last_bout_on :
170- self .bout_onset = 1
171- self .bout_start_t = datetime .datetime .now ()
172-
173- else :
174- self .theta_provided = False
175- self .bout_onset = 0
176-
177- if (
178- not self .theta_provided
179- ): # and (datetime.datetime.now() - self.bout_start_t).total_seconds() > 0.07:
180- # Tail theta:
181- th_n_samples = max (int (round (self .theta_window / self .last_dt )), 2 )
182- n_samples_lag = max (int (round (lag / self .last_dt )), 0 )
183-
184- past_tail_motion = self .acc_tracking .get_last_n (
185- th_n_samples + n_samples_lag
186- )[0 :th_n_samples ]
187- self .tail_th = np .nanmean (
188- np .array (past_tail_motion .tail_sum ) - past_tail_motion .tail_sum .iloc [0 ]
189- )
190- self .theta_provided = True
191- else :
192- self .tail_th = self .tail_th * (3 / 4 )
193- rn = np .random .randint (0 , 1 ) / 100
194- on_ns = self .bout_onset + rn
195- if len (self .log .times ) == 0 or self .log .times [- 1 ] < end_t :
196- self .log .update_list (end_t , self ._output_type (vigor , self .tail_th , on_ns ))
197-
198- if vigor is not None :
199- self .last_vigor = vigor
200-
201- self .last_bout_on = self .bout_on
202-
203- return vigor * self .base_gain , - self .tail_th * 3 , self .bout_on
204-
205-
206- def rot_mat (theta ):
207- """The rotation matrix for an angle theta"""
208- return np .array ([[np .cos (theta ), - np .sin (theta )], [np .sin (theta ), np .cos (theta )]])
122+ return PositionEstimate (
123+ x = current_estimate .x
124+ if abs (current_estimate .x - previous_estimate .x ) > thresholds .x
125+ else previous_estimate .x ,
126+ y = current_estimate .x
127+ if abs (current_estimate .y - previous_estimate .y ) > thresholds .y
128+ else previous_estimate .y ,
129+ theta = current_estimate .x
130+ if abs (reduce_to_pi (current_estimate .theta - previous_estimate .theta ))
131+ > thresholds .theta
132+ else previous_estimate .theta ,
133+ )
209134
210135
211136class PositionEstimator (Estimator ):
212- def __init__ (self , * args , change_thresholds = None , velocity_window = 10 , ** kwargs ):
137+ def __init__ (self , * args , change_thresholds : Optional [ PositionEstimate ] = None , velocity_window : int = 10 , ** kwargs ):
213138 """Uses the projector-to-camera calibration to give fish position in
214139 scree coordinates. If change_thresholds are set, update only the fish
215140 position after there is a big enough change (which prevents small
@@ -223,14 +148,12 @@ def __init__(self, *args, change_thresholds=None, velocity_window=10, **kwargs):
223148 super ().__init__ (* args , ** kwargs )
224149 self .calibrator = self .exp .calibrator
225150 self .last_location = None
226- self .past_values = None
151+ self .previous_position = None
227152
228153 self .velocity_window = velocity_window
229154 self .change_thresholds = change_thresholds
230- if change_thresholds is not None :
231- self .change_thresholds = np .array (change_thresholds )
232155
233- self ._output_type = namedtuple ( "f" , [ "x" , "y" , "theta" ])
156+ self ._output_type = PositionEstimate
234157
235158 def get_camera_position (self ):
236159 past_coords = {
@@ -248,17 +171,10 @@ def get_velocity(self):
248171 )
249172 return np .sqrt (np .sum (vel ** 2 ))
250173
251- def get_istantaneous_velocity (self ):
252- vel_xy = self .acc_tracking .get_last_n (self .velocity_window )[
253- ["f0_vx" , "f0_vy" ]
254- ].values
255- return np .sqrt (np .sum (vel_xy ** 2 ))
256-
257174 def reset (self ):
258- super ().reset ()
259- self .past_values = None
175+ self .previous_position = None
260176
261- def get_position (self ):
177+ def get_position (self ) -> Tuple [ float , PositionEstimate ] :
262178 if len (self .acc_tracking .stored_data ) == 0 or not np .isfinite (
263179 self .acc_tracking .stored_data [- 1 ].f0_x
264180 ):
@@ -286,23 +202,21 @@ def get_position(self):
286202 else :
287203 x , y , theta = past_coords .f0_x , past_coords .f0_y , past_coords .f0_theta
288204
289- c_values = np . array (( y , x , theta ) )
205+ current_position = PositionEstimate ( x , y , theta )
290206
291207 if self .change_thresholds is not None :
208+ if self .previous_position is None :
209+ self .previous_position = current_position
292210
293- if self .past_values is None :
294- self .past_values = np .array (c_values )
295- else :
296- deltas = c_values - self .past_values
297- deltas [2 ] = reduce_to_pi (deltas [2 ])
298- sel = np .abs (deltas ) > self .change_thresholds
299- self .past_values [sel ] = c_values [sel ]
300- c_values = self .past_values
211+ current_position = _propagate_change_above_threshold (
212+ current_position , self .previous_position , self .change_thresholds
213+ )
214+ self .previous_position = current_position
301215
302- logout = self ._output_type (* c_values )
303- self .log .update_list (t , logout )
216+ return t , current_position
304217
305- return c_values
218+ def update (self ):
219+ self .output_queue .put (* self .get_position ())
306220
307221
308222class SimulatedPositionEstimator (Estimator ):
@@ -319,15 +233,18 @@ def __init__(self, *args, motion, **kwargs):
319233 """
320234 super ().__init__ (* args , ** kwargs )
321235 self .motion = motion
322- self ._output_type = namedtuple ( "f" , [ "x" , "y" , "theta" ])
236+ self ._output_type = PositionEstimate
323237
324- def get_position (self ):
238+ def get_position (self ) -> Tuple [ float , PositionEstimate ] :
325239 t = (datetime .datetime .now () - self .exp .t0 ).total_seconds ()
326240
327- kt = tuple (
328- np .interp (t , self .motion .t , self .motion [p ]) for p in ("y " , "x " , "theta" )
241+ kt = PositionEstimate (
242+ * ( np .interp (t , self .motion .t , self .motion [p ]) for p in ("x " , "y " , "theta" ) )
329243 )
330- return kt
244+ return t , kt
245+
246+ def update (self ):
247+ self .output_queue .put (* self .get_position ())
331248
332249
333250estimator_dict = dict (
0 commit comments