Skip to content

Commit 3439e78

Browse files
author
vilim
committed
estimator cleanup
1 parent 77587c6 commit 3439e78

File tree

4 files changed

+93
-170
lines changed

4 files changed

+93
-170
lines changed

stytra/stimulation/estimators.py

Lines changed: 86 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1-
import numpy as np
21
import datetime
2+
from collections import namedtuple
3+
from typing import NamedTuple, Optional, Tuple
4+
5+
import numpy as np
36

47
from stytra.collectors import QueueDataAccumulator
8+
from stytra.collectors.namedtuplequeue import NamedTupleQueue
59
from stytra.utilities import reduce_to_pi
6-
from collections import namedtuple
710

811

912
class Estimator:
@@ -15,11 +18,19 @@ class Estimator:
1518

1619
def __init__(self, acc_tracking: QueueDataAccumulator, experiment):
1720
self.exp = experiment
18-
self.log = experiment.estimator_log
1921
self.acc_tracking = acc_tracking
22+
self.output_queue = NamedTupleQueue()
23+
self._output_type = None
24+
25+
def update(self):
26+
raise NotImplementedError
2027

2128
def reset(self):
22-
self.log.reset()
29+
pass
30+
31+
32+
class VigorEstimate(NamedTuple):
33+
vigor: float
2334

2435

2536
class VigorMotionEstimator(Estimator):
@@ -34,27 +45,15 @@ def __init__(self, *args, vigor_window=0.050, base_gain=-12, **kwargs):
3445
self.vigor_window = vigor_window
3546
self.last_dt = 1 / 500.0
3647
self.base_gain = base_gain
37-
self._output_type = namedtuple("s", "vigor")
38-
39-
def get_velocity(self, lag=0):
40-
"""
41-
42-
Parameters
43-
----------
44-
lag :
45-
(Default value = 0)
48+
self._output_type = namedtuple("vigor_estimate", ("vigor",))
4649

47-
Returns
48-
-------
49-
50-
"""
50+
def get_vigor(self):
5151
vigor_n_samples = max(int(round(self.vigor_window / self.last_dt)), 2)
52-
n_samples_lag = max(int(round(lag / self.last_dt)), 0)
5352
if not self.acc_tracking.stored_data:
5453
return 0
55-
past_tail_motion = self.acc_tracking.get_last_n(
56-
vigor_n_samples + n_samples_lag
57-
)[0:vigor_n_samples]
54+
past_tail_motion = self.acc_tracking.get_last_n(vigor_n_samples)[
55+
0:vigor_n_samples
56+
]
5857
end_t = past_tail_motion.t.iloc[-1]
5958
start_t = past_tail_motion.t.iloc[0]
6059
new_dt = (end_t - start_t) / vigor_n_samples
@@ -63,10 +62,15 @@ def get_velocity(self, lag=0):
6362
vigor = np.nanstd(np.array(past_tail_motion.tail_sum))
6463
if np.isnan(vigor):
6564
vigor = 0
65+
return end_t, vigor
66+
67+
def update(self):
68+
end_t, vigor = self.get_vigor()
69+
self.output_queue.put(end_t, self._output_type(vigor))
6670

67-
if len(self.log.times) == 0 or self.log.times[-1] < end_t:
68-
self.log.update_list(end_t, self._output_type(vigor))
69-
return vigor * self.base_gain
71+
72+
class BoutEstimate(NamedTuple):
73+
is_bouting: bool
7074

7175

7276
class BoutsEstimator(VigorMotionEstimator):
@@ -78,138 +82,59 @@ def __init__(
7882
self.vigor_window = vigor_window
7983
self.min_interbout = min_interbout
8084
self.last_bout_t = None
85+
self._output_type = namedtuple("bouts", ("is_bouting",))
8186

82-
def bout_occured(self):
83-
if self.get_velocity() > self.base_gain * self.bout_threshold:
87+
def update(self):
88+
end_t, vigor = self.get_vigor()
89+
is_bouting = False
90+
if vigor > self.base_gain * self.bout_threshold:
8491
if (
8592
self.last_bout_t is None
8693
or (datetime.datetime.now() - self.last_bout_t).total_seconds()
8794
> self.min_interbout
8895
):
8996
self.last_bout_t = datetime.datetime.now()
90-
return True
91-
return False
97+
is_bouting = True
98+
self.output_queue.put(end_t, self._output_type(is_bouting))
9299

93100

94-
class TailSumEstimator(Estimator):
95-
def __init__(
96-
self,
97-
*args,
98-
vigor_window=0.050,
99-
theta_window=0.07,
100-
base_gain=-30,
101-
bout_threshold=0.05,
102-
min_interbout=0.1,
103-
**kwargs
104-
):
105-
super().__init__(*args, **kwargs)
106-
self.vigor_window = vigor_window
107-
self.theta_window = theta_window
108-
self.last_dt = 1 / 500.0
109-
self.base_gain = base_gain
110-
self._output_type = namedtuple("s", ("vigor", "theta", "bout_on"))
111-
self.bout_threshold = bout_threshold
112-
self.vigor_window = vigor_window
113-
self.min_interbout = min_interbout
114-
self.last_bout_t = None
115-
self.prev_time_on = False
116-
self.bout_onset = 0
117-
self.bout_on = 0
118-
self.theta_provided = True
119-
self.last_bout_t = 0
120-
self.last_vigor = 0
121-
self.last_bout_on = 0
122-
123-
self.tail_th = 0
124-
125-
def bout_occured(self):
126-
if self.bout_on:
127-
if (
128-
self.last_bout_t is None
129-
or (datetime.datetime.now() - self.last_bout_t).total_seconds()
130-
> self.min_interbout
131-
):
132-
self.last_bout_t = datetime.datetime.now()
133-
return True
134-
return False
101+
class EmbeddedBoutEstimate(NamedTuple):
102+
vigor: float
103+
theta: float
104+
bout_on: bool
135105

136-
def get_vel_and_theta(self, lag=0):
137-
"""
138106

139-
Parameters
140-
----------
141-
lag :
142-
(Default value = 0)
107+
class PositionEstimate(NamedTuple):
108+
x: float
109+
y: float
110+
theta: float
143111

144-
Returns
145-
-------
146112

147-
"""
148-
# Vigor (copypasted from VigorEstimator method for simplicity)
149-
vigor_n_samples = max(int(round(self.vigor_window / self.last_dt)), 2)
150-
n_samples_lag = max(int(round(lag / self.last_dt)), 0)
151-
if not self.acc_tracking.stored_data:
152-
return 0, 0, 0
153-
past_tail_motion = self.acc_tracking.get_last_n(
154-
vigor_n_samples + n_samples_lag
155-
)[0:vigor_n_samples]
156-
end_t = past_tail_motion.t.iloc[-1]
157-
start_t = past_tail_motion.t.iloc[0]
158-
new_dt = (end_t - start_t) / vigor_n_samples
159-
if new_dt > 0:
160-
self.last_dt = new_dt
161-
vigor = np.nanstd(np.array(past_tail_motion.tail_sum))
113+
def _propagate_change_above_threshold(
114+
current_estimate: PositionEstimate,
115+
previous_estimate: Optional[PositionEstimate],
116+
thresholds: PositionEstimate,
117+
) -> PositionEstimate:
118+
"""Return updated components of a position if the component changed enough, otherwise return the old component"""
119+
if previous_estimate is None:
120+
return current_estimate
162121

163-
if vigor is not None:
164-
self.bout_on = int(vigor > self.bout_threshold)
165-
else:
166-
self.bout_on = int(self.last_vigor > self.bout_threshold)
167-
168-
if self.bout_onset == 0:
169-
if self.bout_on and not self.last_bout_on:
170-
self.bout_onset = 1
171-
self.bout_start_t = datetime.datetime.now()
172-
173-
else:
174-
self.theta_provided = False
175-
self.bout_onset = 0
176-
177-
if (
178-
not self.theta_provided
179-
): # and (datetime.datetime.now() - self.bout_start_t).total_seconds() > 0.07:
180-
# Tail theta:
181-
th_n_samples = max(int(round(self.theta_window / self.last_dt)), 2)
182-
n_samples_lag = max(int(round(lag / self.last_dt)), 0)
183-
184-
past_tail_motion = self.acc_tracking.get_last_n(
185-
th_n_samples + n_samples_lag
186-
)[0:th_n_samples]
187-
self.tail_th = np.nanmean(
188-
np.array(past_tail_motion.tail_sum) - past_tail_motion.tail_sum.iloc[0]
189-
)
190-
self.theta_provided = True
191-
else:
192-
self.tail_th = self.tail_th * (3 / 4)
193-
rn = np.random.randint(0, 1) / 100
194-
on_ns = self.bout_onset + rn
195-
if len(self.log.times) == 0 or self.log.times[-1] < end_t:
196-
self.log.update_list(end_t, self._output_type(vigor, self.tail_th, on_ns))
197-
198-
if vigor is not None:
199-
self.last_vigor = vigor
200-
201-
self.last_bout_on = self.bout_on
202-
203-
return vigor * self.base_gain, -self.tail_th * 3, self.bout_on
204-
205-
206-
def rot_mat(theta):
207-
"""The rotation matrix for an angle theta"""
208-
return np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
122+
return PositionEstimate(
123+
x=current_estimate.x
124+
if abs(current_estimate.x - previous_estimate.x) > thresholds.x
125+
else previous_estimate.x,
126+
y=current_estimate.x
127+
if abs(current_estimate.y - previous_estimate.y) > thresholds.y
128+
else previous_estimate.y,
129+
theta=current_estimate.x
130+
if abs(reduce_to_pi(current_estimate.theta - previous_estimate.theta))
131+
> thresholds.theta
132+
else previous_estimate.theta,
133+
)
209134

210135

211136
class PositionEstimator(Estimator):
212-
def __init__(self, *args, change_thresholds=None, velocity_window=10, **kwargs):
137+
def __init__(self, *args, change_thresholds:Optional[PositionEstimate]=None, velocity_window:int=10, **kwargs):
213138
"""Uses the projector-to-camera calibration to give fish position in
214139
scree coordinates. If change_thresholds are set, update only the fish
215140
position after there is a big enough change (which prevents small
@@ -223,14 +148,12 @@ def __init__(self, *args, change_thresholds=None, velocity_window=10, **kwargs):
223148
super().__init__(*args, **kwargs)
224149
self.calibrator = self.exp.calibrator
225150
self.last_location = None
226-
self.past_values = None
151+
self.previous_position = None
227152

228153
self.velocity_window = velocity_window
229154
self.change_thresholds = change_thresholds
230-
if change_thresholds is not None:
231-
self.change_thresholds = np.array(change_thresholds)
232155

233-
self._output_type = namedtuple("f", ["x", "y", "theta"])
156+
self._output_type = PositionEstimate
234157

235158
def get_camera_position(self):
236159
past_coords = {
@@ -248,17 +171,10 @@ def get_velocity(self):
248171
)
249172
return np.sqrt(np.sum(vel ** 2))
250173

251-
def get_istantaneous_velocity(self):
252-
vel_xy = self.acc_tracking.get_last_n(self.velocity_window)[
253-
["f0_vx", "f0_vy"]
254-
].values
255-
return np.sqrt(np.sum(vel_xy ** 2))
256-
257174
def reset(self):
258-
super().reset()
259-
self.past_values = None
175+
self.previous_position = None
260176

261-
def get_position(self):
177+
def get_position(self) -> Tuple[float, PositionEstimate]:
262178
if len(self.acc_tracking.stored_data) == 0 or not np.isfinite(
263179
self.acc_tracking.stored_data[-1].f0_x
264180
):
@@ -286,23 +202,21 @@ def get_position(self):
286202
else:
287203
x, y, theta = past_coords.f0_x, past_coords.f0_y, past_coords.f0_theta
288204

289-
c_values = np.array((y, x, theta))
205+
current_position = PositionEstimate(x, y, theta)
290206

291207
if self.change_thresholds is not None:
208+
if self.previous_position is None:
209+
self.previous_position = current_position
292210

293-
if self.past_values is None:
294-
self.past_values = np.array(c_values)
295-
else:
296-
deltas = c_values - self.past_values
297-
deltas[2] = reduce_to_pi(deltas[2])
298-
sel = np.abs(deltas) > self.change_thresholds
299-
self.past_values[sel] = c_values[sel]
300-
c_values = self.past_values
211+
current_position = _propagate_change_above_threshold(
212+
current_position, self.previous_position, self.change_thresholds
213+
)
214+
self.previous_position = current_position
301215

302-
logout = self._output_type(*c_values)
303-
self.log.update_list(t, logout)
216+
return t, current_position
304217

305-
return c_values
218+
def update(self):
219+
self.output_queue.put(*self.get_position())
306220

307221

308222
class SimulatedPositionEstimator(Estimator):
@@ -319,15 +233,18 @@ def __init__(self, *args, motion, **kwargs):
319233
"""
320234
super().__init__(*args, **kwargs)
321235
self.motion = motion
322-
self._output_type = namedtuple("f", ["x", "y", "theta"])
236+
self._output_type = PositionEstimate
323237

324-
def get_position(self):
238+
def get_position(self) -> Tuple[float, PositionEstimate]:
325239
t = (datetime.datetime.now() - self.exp.t0).total_seconds()
326240

327-
kt = tuple(
328-
np.interp(t, self.motion.t, self.motion[p]) for p in ("y", "x", "theta")
241+
kt = PositionEstimate(
242+
*(np.interp(t, self.motion.t, self.motion[p]) for p in ("x", "y", "theta"))
329243
)
330-
return kt
244+
return t, kt
245+
246+
def update(self):
247+
self.output_queue.put(*self.get_position())
331248

332249

333250
estimator_dict = dict(

stytra/stimulation/stimuli/closed_loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(
6666
self.prev_bout_t = 0
6767

6868
def get_fish_vel(self):
69-
"""Function that update estimated fish velocty. Change to add lag or
69+
"""Function that update estimated fish velocity. Change to add lag or
7070
shunting.
7171
"""
7272
self.fish_vel = self._experiment.estimator.get_velocity()

stytra/stimulation/stimuli/generic_stimuli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(self, duration=0.0):
6666
self._elapsed = 0.0 # time from the beginning of the stimulus
6767
self.name = "undefined"
6868
self._experiment = None
69+
self._input_queue = None
6970
self.real_time_start = None
7071
self.real_time_stop = None
7172

stytra/utilities.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,3 +278,8 @@ def save_df(df, path, fileformat):
278278
else:
279279
raise (NotImplementedError(fileformat + " is not an implemented log format"))
280280
return outpath.name
281+
282+
283+
def rot_mat(theta):
284+
"""The rotation matrix for an angle theta"""
285+
return np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])

0 commit comments

Comments
 (0)