Skip to content

Commit eb18b54

Browse files
author
vilim
committed
estimator in progress
1 parent 3439e78 commit eb18b54

File tree

7 files changed

+86
-50
lines changed

7 files changed

+86
-50
lines changed

stytra/collectors/accumulators.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
from PyQt5.QtCore import QObject, pyqtSignal
24
import datetime
35
import numpy as np
@@ -7,14 +9,15 @@
79
from bisect import bisect_right
810
from os.path import basename
911

12+
from stytra.collectors.namedtuplequeue import NamedTupleQueue
1013
from stytra.utilities import save_df
1114

1215

1316
class Accumulator(QObject):
1417
def __init__(self, experiment, name="", max_history_if_not_running=1000):
1518
super().__init__()
1619
self.name = name
17-
self.exp = experiment
20+
#self.exp = experiment
1821
self.stored_data = []
1922
self.times = []
2023
self.max_history_if_not_running = max_history_if_not_running
@@ -76,20 +79,23 @@ def __getitem__(self, item):
7679
def t(self):
7780
return np.array(self.times)
7881

79-
def values_at_abs_time(self, time):
82+
def values_at_abs_time(self, time, t0):
8083
"""Finds the values in the accumulator closest to the datetime time
8184
8285
Parameters
8386
----------
8487
time : datetime
8588
time to search for
8689
90+
t0:
91+
reference time 0
92+
8793
Returns
8894
-------
8995
namedtuple of values
9096
9197
"""
92-
find_time = (time - self.exp.t0).total_seconds()
98+
find_time = (time - t0).total_seconds()
9399
i = bisect_right(self.times, find_time)
94100
return self.stored_data[i - 1]
95101

@@ -239,31 +245,40 @@ class QueueDataAccumulator(DataFrameAccumulator):
239245
240246
Parameters
241247
----------
242-
data_queue : (multiprocessing.Queue object)
248+
data_queue : NamedTupleQueue
243249
queue from witch to retrieve data.
250+
output_queue:Optional[NamedTupleQueue]
251+
an optinal queue to forward the data to
244252
header_list : list of str
245253
headers for the data to stored.
246254
247-
Returns
248-
-------
249-
250255
"""
251256

252-
def __init__(self, data_queue, **kwargs):
257+
def __init__(
258+
self,
259+
data_queue: NamedTupleQueue,
260+
output_queue: Optional[NamedTupleQueue] = None,
261+
**kwargs
262+
):
253263
""" """
254264
super().__init__(**kwargs)
255265

256266
# Store externally the starting time make us free to keep
257267
# only time differences in milliseconds in the list (faster)
258268
self.starting_time = None
259269
self.data_queue = data_queue
270+
self.output_queue = output_queue
260271

261272
def update_list(self):
262273
"""Upon calling put all available data into a list."""
263274
while True:
264275
try:
265276
# Get data from queue:
266277
t, data = self.data_queue.get(timeout=0.001)
278+
279+
if self.output_queue is not None:
280+
self.output_queue.put(t, data)
281+
267282
newtype = False
268283
if len(self.stored_data) == 0 or type(data) != type(
269284
self.stored_data[-1]
@@ -313,7 +328,7 @@ def __init__(self, *args, queue, **kwargs):
313328
super().__init__(*args, **kwargs)
314329
self.queue = queue
315330

316-
def update_list(self):
331+
def update_list(self, fps):
317332
while True:
318333
try:
319334
# Get data from queue:

stytra/examples/custom_tracking_exp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def retrieve_image(self):
180180
# To match tracked points and frame displayed looks for matching
181181
# timestamps of the displayed frame and of tracked queue:
182182
retrieved_data = self.experiment.acc_tracking.values_at_abs_time(
183-
self.current_frame_time
183+
self.current_frame_time, self.experiment.t0
184184
)
185185

186186
# Check for valid data to be displayed:

stytra/experiments/tracking_experiments.py

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
EstimatorLog,
2222
FramerateQueueAccumulator,
2323
)
24+
from stytra.stimulation.estimator_process import EstimatorProcess
2425
from stytra.tracking.tracking_process import TrackingProcess
2526
from stytra.tracking.pipelines import Pipeline
2627
from stytra.collectors.namedtuplequeue import NamedTupleQueue
@@ -191,9 +192,7 @@ class TrackingExperiment(CameraVisualExperiment):
191192
192193
"""
193194

194-
def __init__(
195-
self, *args, tracking, recording=None, second_output_queue=None, **kwargs
196-
):
195+
def __init__(self, *args, tracking, recording=None, second_output_queue=None, **kwargs):
197196
"""
198197
:param tracking_method: class with the parameters for tracking (instance
199198
of TrackingMethod class, defined in the child);
@@ -210,14 +209,10 @@ def __init__(
210209
super().__init__(*args, **kwargs)
211210
self.arguments.update(locals())
212211

213-
self.recording_event = (
214-
Event() if (recording is not None or recording is False) else None
215-
)
212+
self.recording_event = Event() if (recording is not None or recording is False) else None
216213

217214
self.pipeline_cls = (
218-
pipeline_dict.get(tracking["method"], None)
219-
if isinstance(tracking["method"], str)
220-
else tracking["method"]
215+
pipeline_dict.get(tracking["method"], None) if isinstance(tracking["method"], str) else tracking["method"]
221216
)
222217

223218
self.frame_dispatcher = TrackingProcess(
@@ -237,20 +232,6 @@ def __init__(
237232
assert isinstance(self.pipeline, Pipeline)
238233
self.pipeline.setup(tree=self.dc)
239234

240-
self.acc_tracking = QueueDataAccumulator(
241-
name="tracking",
242-
experiment=self,
243-
data_queue=self.tracking_output_queue,
244-
monitored_headers=self.pipeline.headers_to_plot,
245-
)
246-
self.acc_tracking.sig_acc_init.connect(self.refresh_plots)
247-
248-
# Data accumulator is updated with GUI timer:
249-
self.gui_timer.timeout.connect(self.acc_tracking.update_list)
250-
251-
# Tracking is reset at experiment start:
252-
self.protocol_runner.sig_protocol_started.connect(self.acc_tracking.reset)
253-
254235
# start frame dispatcher process:
255236
self.frame_dispatcher.start()
256237

@@ -263,15 +244,28 @@ def __init__(
263244
est = est_type
264245

265246
if est is not None:
247+
self.estimator_process = EstimatorProcess(est_type, self.tracking_output_queue, self.finished_sig)
266248
self.estimator_log = EstimatorLog(experiment=self)
267-
self.estimator = est(
268-
self.acc_tracking,
269-
experiment=self,
270-
**tracking.get("estimator_params", {})
271-
)
249+
self.estimator = est(self.acc_tracking, experiment=self, **tracking.get("estimator_params", {}))
272250
self.estimator_log.sig_acc_init.connect(self.refresh_plots)
251+
tracking_output_queue = self.estimator_process.tracking_output_queue
273252
else:
274253
self.estimator = None
254+
tracking_output_queue = self.tracking_output_queue
255+
256+
self.acc_tracking = QueueDataAccumulator(
257+
name="tracking",
258+
experiment=self,
259+
data_queue=tracking_output_queue,
260+
monitored_headers=self.pipeline.headers_to_plot,
261+
)
262+
self.acc_tracking.sig_acc_init.connect(self.refresh_plots)
263+
264+
# Data accumulator is updated with GUI timer:
265+
self.gui_timer.timeout.connect(self.acc_tracking.update_list)
266+
267+
# Tracking is reset at experiment start:
268+
self.protocol_runner.sig_protocol_started.connect(self.acc_tracking.reset)
275269

276270
self.acc_tracking_framerate = FramerateQueueAccumulator(
277271
self,
@@ -376,9 +370,7 @@ def end_protocol(self, save=True):
376370
def save_data(self):
377371
"""Save tail position and dynamic parameters and terminate."""
378372

379-
self.window_main.camera_display.save_image(
380-
name=self.filename_base() + "img.png"
381-
)
373+
self.window_main.camera_display.save_image(name=self.filename_base() + "img.png")
382374
self.dc.add_static_data(self.filename_prefix() + "img.png", "tracking/image")
383375

384376
# Save log and estimators:

stytra/gui/camera_display.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def retrieve_image(self):
340340
# To match tracked points and frame displayed looks for matching
341341
# timestamps from the two different queues:
342342
retrieved_data = self.experiment.acc_tracking.values_at_abs_time(
343-
self.current_frame_time
343+
self.current_frame_time, self.experiment.t0
344344
)
345345
# Check for data to be displayed:
346346
# Retrieve tail angles from tail
@@ -442,7 +442,7 @@ def retrieve_image(self):
442442
# To match tracked points and frame displayed looks for matching
443443
# timestamps from the two different queues:
444444
retrieved_data = self.experiment.acc_tracking.values_at_abs_time(
445-
self.current_frame_time
445+
self.current_frame_time, self.experiment.t0
446446
)
447447
# Check for data to be displayed:
448448

@@ -622,7 +622,7 @@ def retrieve_image(self):
622622
return
623623

624624
current_data = self.experiment.acc_tracking.values_at_abs_time(
625-
self.current_frame_time
625+
self.current_frame_time, self.experiment.t0
626626
)
627627

628628
n_fish = self.tracking_params.n_fish_max
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from multiprocessing import Event, Process
2+
from typing import Type
3+
4+
from stytra.collectors import QueueDataAccumulator
5+
from stytra.collectors.namedtuplequeue import NamedTupleQueue
6+
from stytra.stimulation.estimators import Estimator
7+
8+
9+
class EstimatorProcess(Process):
10+
def __init__(
11+
self,
12+
estimator_cls: Type[Estimator],
13+
tracking_queue: NamedTupleQueue,
14+
finished_signal: Event,
15+
):
16+
super().__init__()
17+
self.tracking_queue = tracking_queue
18+
self.tracking_output_queue = NamedTupleQueue()
19+
self.estimator_queue = NamedTupleQueue()
20+
self.tracking_accumulator = QueueDataAccumulator(self.tracking_queue, self.tracking_output_queue)
21+
self.finished_signal = finished_signal
22+
self.estimator_cls = estimator_cls
23+
24+
def run(self):
25+
estimator = self.estimator_cls(self.tracking_accumulator, self.estimator_queue)
26+
27+
while not self.finished_signal.is_set():
28+
self.tracking_accumulator.update_list()
29+
estimator.update()

stytra/stimulation/estimators.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ class Estimator:
1616
stream of the tracking pipelines (position in pixels, tail angles, etc.).
1717
"""
1818

19-
def __init__(self, acc_tracking: QueueDataAccumulator, experiment):
20-
self.exp = experiment
19+
def __init__(self, acc_tracking: QueueDataAccumulator, output_queue: NamedTupleQueue, cam_to_proj=None):
2120
self.acc_tracking = acc_tracking
22-
self.output_queue = NamedTupleQueue()
21+
self.output_queue = output_queue
22+
self.cam_to_proj = cam_to_proj
2323
self._output_type = None
2424

2525
def update(self):
@@ -184,8 +184,8 @@ def get_position(self) -> Tuple[float, PositionEstimate]:
184184
past_coords = self.acc_tracking.stored_data[-1]
185185
t = self.acc_tracking.times[-1]
186186

187-
if not self.calibrator.cam_to_proj is None:
188-
projmat = np.array(self.calibrator.cam_to_proj)
187+
if not self.cam_to_proj is None:
188+
projmat = np.array(self.cam_to_proj)
189189
if projmat.shape != (2, 3):
190190
projmat = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
191191

stytra/tracking/tracking_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from queue import Empty, Full
2-
from multiprocessing import Event, Value
2+
from multiprocessing import Event
33

44
from stytra.utilities import FrameProcess
55
from arrayqueues.shared_arrays import TimestampedArrayQueue

0 commit comments

Comments
 (0)