Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 44 additions & 29 deletions stytra/collectors/accumulators.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from PyQt5.QtCore import QObject, pyqtSignal
import datetime
import numpy as np
Expand All @@ -7,23 +9,40 @@
from bisect import bisect_right
from os.path import basename

from stytra.collectors.namedtuplequeue import NamedTupleQueue
from stytra.utilities import save_df


class Accumulator(QObject):
def __init__(self, experiment, name="", max_history_if_not_running=1000):
def __init__(self, name="", max_trimmed_len=1000, trim = False):
super().__init__()
self.name = name
self.exp = experiment
self.stored_data = []
self.times = []
self.max_history_if_not_running = max_history_if_not_running
self.max_trimmed_len = max_trimmed_len
self._trim = trim #

@property
def trim(self) -> bool:
return self._trim

def trim_data(self):
if self.trim and len(self.times) > self.max_trimmed_len * 1.5:
self.times[: -self.max_trimmed_len] = []
self.stored_data[: -self.max_trimmed_len] = []

@property
def t0(self) -> float:
raise NotImplementedError

def is_empty(self) -> bool:
return len(self.stored_data) == 0


class DataFrameAccumulator(Accumulator):
"""Abstract class for accumulating streams of data.

It is use to save or plot in real time data from stimulus logs or
It is used to save or plot in real time data from stimulus logs or
behavior tracking. Data is stored in a list in the stored_data
attribute.

Expand Down Expand Up @@ -76,20 +95,23 @@ def __getitem__(self, item):
def t(self):
return np.array(self.times)

def values_at_abs_time(self, time):
def values_at_abs_time(self, time, t0):
"""Finds the values in the accumulator closest to the datetime time

Parameters
----------
time : datetime
time to search for

t0:
reference time 0

Returns
-------
namedtuple of values

"""
find_time = (time - self.exp.t0).total_seconds()
find_time = (time - t0).total_seconds()
i = bisect_right(self.times, find_time)
return self.stored_data[i - 1]

Expand Down Expand Up @@ -128,14 +150,6 @@ def reset(self, monitored_headers=None):

self._header_dict = None

def trim_data(self):
if (
not self.exp.protocol_runner.running
and len(self.times) > self.max_history_if_not_running * 1.5
):
self.times[: -self.max_history_if_not_running] = []
self.stored_data[: -self.max_history_if_not_running] = []

def get_fps(self):
""" """
try:
Expand Down Expand Up @@ -223,9 +237,6 @@ def save(self, path, format="csv"):
saved_filename = save_df(df, path, format)
return basename(saved_filename)

def is_empty(self):
return len(self.stored_data) == 0


class QueueDataAccumulator(DataFrameAccumulator):
"""General class for retrieving data from a Queue.
Expand All @@ -239,31 +250,40 @@ class QueueDataAccumulator(DataFrameAccumulator):

Parameters
----------
data_queue : (multiprocessing.Queue object)
data_queue : NamedTupleQueue
queue from witch to retrieve data.
output_queue:Optional[NamedTupleQueue]
an optional queue to forward the data to
header_list : list of str
headers for the data to stored.

Returns
-------
headers for the data to be stored.

"""

def __init__(self, data_queue, **kwargs):
def __init__(
self,
data_queue: NamedTupleQueue,
output_queue: Optional[NamedTupleQueue] = None,
**kwargs
):
""" """
super().__init__(**kwargs)

# Store externally the starting time make us free to keep
# only time differences in milliseconds in the list (faster)
self.starting_time = None
self.data_queue = data_queue
self.output_queue = output_queue

def update_list(self):
"""Upon calling put all available data into a list."""
while True:
try:
# Get data from queue:
t, data = self.data_queue.get(timeout=0.001)

if self.output_queue is not None:
self.output_queue.put(t, data)

newtype = False
if len(self.stored_data) == 0 or type(data) != type(
self.stored_data[-1]
Expand Down Expand Up @@ -292,11 +312,6 @@ def __init__(self, *args, goal_framerate=None, **kwargs):
super().__init__(*args, **kwargs)
self.goal_framerate = goal_framerate

def trim_data(self):
if len(self.times) > self.max_history_if_not_running * 1.5:
self.times[: -self.max_history_if_not_running] = []
self.stored_data[: -self.max_history_if_not_running] = []

def reset(self):
self.times = []
self.stored_data = []
Expand All @@ -313,7 +328,7 @@ def __init__(self, *args, queue, **kwargs):
super().__init__(*args, **kwargs)
self.queue = queue

def update_list(self):
def update_list(self, fps):
while True:
try:
# Get data from queue:
Expand Down
2 changes: 1 addition & 1 deletion stytra/examples/custom_tracking_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def retrieve_image(self):
# To match tracked points and frame displayed looks for matching
# timestamps of the displayed frame and of tracked queue:
retrieved_data = self.experiment.acc_tracking.values_at_abs_time(
self.current_frame_time
self.current_frame_time, self.experiment.t0
)

# Check for valid data to be displayed:
Expand Down
60 changes: 29 additions & 31 deletions stytra/experiments/tracking_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
EstimatorLog,
FramerateQueueAccumulator,
)
from stytra.stimulation.estimator_process import EstimatorProcess
from stytra.tracking.tracking_process import TrackingProcess
from stytra.tracking.pipelines import Pipeline
from stytra.collectors.namedtuplequeue import NamedTupleQueue
Expand Down Expand Up @@ -191,9 +192,7 @@ class TrackingExperiment(CameraVisualExperiment):

"""

def __init__(
self, *args, tracking, recording=None, second_output_queue=None, **kwargs
):
def __init__(self, *args, tracking, recording=None, second_output_queue=None, **kwargs):
"""
:param tracking_method: class with the parameters for tracking (instance
of TrackingMethod class, defined in the child);
Expand All @@ -210,14 +209,10 @@ def __init__(
super().__init__(*args, **kwargs)
self.arguments.update(locals())

self.recording_event = (
Event() if (recording is not None or recording is False) else None
)
self.recording_event = Event() if (recording is not None or recording is False) else None

self.pipeline_cls = (
pipeline_dict.get(tracking["method"], None)
if isinstance(tracking["method"], str)
else tracking["method"]
pipeline_dict.get(tracking["method"], None) if isinstance(tracking["method"], str) else tracking["method"]
)

self.frame_dispatcher = TrackingProcess(
Expand All @@ -237,20 +232,6 @@ def __init__(
assert isinstance(self.pipeline, Pipeline)
self.pipeline.setup(tree=self.dc)

self.acc_tracking = QueueDataAccumulator(
name="tracking",
experiment=self,
data_queue=self.tracking_output_queue,
monitored_headers=self.pipeline.headers_to_plot,
)
self.acc_tracking.sig_acc_init.connect(self.refresh_plots)

# Data accumulator is updated with GUI timer:
self.gui_timer.timeout.connect(self.acc_tracking.update_list)

# Tracking is reset at experiment start:
self.protocol_runner.sig_protocol_started.connect(self.acc_tracking.reset)

# start frame dispatcher process:
self.frame_dispatcher.start()

Expand All @@ -263,15 +244,34 @@ def __init__(
est = est_type

if est is not None:
self.estimator_process = EstimatorProcess(est_type, self.tracking_output_queue, self.finished_sig)
self.estimator_log = EstimatorLog(experiment=self)
self.estimator = est(
self.acc_tracking,
experiment=self,
**tracking.get("estimator_params", {})
)
self.estimator = est(self.acc_tracking, experiment=self)
first_est_params = tracking.get("estimator_params", None)
if first_est_params is not None:
self.estimator_process.estimator_parameter_queue.put(first_est_params)

self.estimator_log.sig_acc_init.connect(self.refresh_plots)
tracking_output_queue = self.estimator_process.tracking_output_queue
self.protocol_runner.attach_estimator_queue(self.est)
self.estimator_process.start()
else:
self.estimator = None
tracking_output_queue = self.tracking_output_queue

self.acc_tracking = QueueDataAccumulator(
name="tracking",
experiment=self,
data_queue=tracking_output_queue,
monitored_headers=self.pipeline.headers_to_plot,
)
self.acc_tracking.sig_acc_init.connect(self.refresh_plots)

# Data accumulator is updated with GUI timer:
self.gui_timer.timeout.connect(self.acc_tracking.update_list)

# Tracking is reset at experiment start:
self.protocol_runner.sig_protocol_started.connect(self.acc_tracking.reset)

self.acc_tracking_framerate = FramerateQueueAccumulator(
self,
Expand Down Expand Up @@ -376,9 +376,7 @@ def end_protocol(self, save=True):
def save_data(self):
"""Save tail position and dynamic parameters and terminate."""

self.window_main.camera_display.save_image(
name=self.filename_base() + "img.png"
)
self.window_main.camera_display.save_image(name=self.filename_base() + "img.png")
self.dc.add_static_data(self.filename_prefix() + "img.png", "tracking/image")

# Save log and estimators:
Expand Down
6 changes: 3 additions & 3 deletions stytra/gui/camera_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def retrieve_image(self):
# To match tracked points and frame displayed looks for matching
# timestamps from the two different queues:
retrieved_data = self.experiment.acc_tracking.values_at_abs_time(
self.current_frame_time
self.current_frame_time, self.experiment.t0
)
# Check for data to be displayed:
# Retrieve tail angles from tail
Expand Down Expand Up @@ -442,7 +442,7 @@ def retrieve_image(self):
# To match tracked points and frame displayed looks for matching
# timestamps from the two different queues:
retrieved_data = self.experiment.acc_tracking.values_at_abs_time(
self.current_frame_time
self.current_frame_time, self.experiment.t0
)
# Check for data to be displayed:

Expand Down Expand Up @@ -622,7 +622,7 @@ def retrieve_image(self):
return

current_data = self.experiment.acc_tracking.values_at_abs_time(
self.current_frame_time
self.current_frame_time, self.experiment.t0
)

n_fish = self.tracking_params.n_fish_max
Expand Down
43 changes: 43 additions & 0 deletions stytra/stimulation/estimator_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from multiprocessing import Event, Process, Queue
from queue import Empty
from typing import Type

from stytra.collectors import QueueDataAccumulator
from stytra.collectors.namedtuplequeue import NamedTupleQueue
from stytra.stimulation.estimators import Estimator


class EstimatorProcess(Process):
def __init__(
self,
estimator_cls: Type[Estimator],
tracking_queue: NamedTupleQueue,
estimator_parameter_queue: Queue,
finished_signal: Event,
):
super().__init__()
self.tracking_queue = tracking_queue
self.tracking_output_queue = NamedTupleQueue()
self.estimator_parameter_queue = estimator_parameter_queue
self.estimator_queue = NamedTupleQueue()
self.tracking_accumulator = QueueDataAccumulator(self.tracking_queue, self.tracking_output_queue)
self.finished_signal = finished_signal
self.estimator_cls = estimator_cls


def update_estimator_params(self, estimator):
while True:
try:
param_dict = self.estimator_parameter_queue.get(timeout=0.0001)
estimator.update_params(param_dict)
except Empty:
break


def run(self):
estimator = self.estimator_cls(self.tracking_accumulator, self.estimator_queue)

while not self.finished_signal.is_set():
self.update_estimator_params(estimator)
self.tracking_accumulator.update_list()
estimator.update()
Loading