2121 EstimatorLog ,
2222 FramerateQueueAccumulator ,
2323)
24+ from stytra .stimulation .estimator_process import EstimatorProcess
2425from stytra .tracking .tracking_process import TrackingProcess
2526from stytra .tracking .pipelines import Pipeline
2627from stytra .collectors .namedtuplequeue import NamedTupleQueue
@@ -191,9 +192,7 @@ class TrackingExperiment(CameraVisualExperiment):
191192
192193 """
193194
194- def __init__ (
195- self , * args , tracking , recording = None , second_output_queue = None , ** kwargs
196- ):
195+ def __init__ (self , * args , tracking , recording = None , second_output_queue = None , ** kwargs ):
197196 """
198197 :param tracking_method: class with the parameters for tracking (instance
199198 of TrackingMethod class, defined in the child);
@@ -210,14 +209,10 @@ def __init__(
210209 super ().__init__ (* args , ** kwargs )
211210 self .arguments .update (locals ())
212211
213- self .recording_event = (
214- Event () if (recording is not None or recording is False ) else None
215- )
212+ self .recording_event = Event () if (recording is not None or recording is False ) else None
216213
217214 self .pipeline_cls = (
218- pipeline_dict .get (tracking ["method" ], None )
219- if isinstance (tracking ["method" ], str )
220- else tracking ["method" ]
215+ pipeline_dict .get (tracking ["method" ], None ) if isinstance (tracking ["method" ], str ) else tracking ["method" ]
221216 )
222217
223218 self .frame_dispatcher = TrackingProcess (
@@ -237,20 +232,6 @@ def __init__(
237232 assert isinstance (self .pipeline , Pipeline )
238233 self .pipeline .setup (tree = self .dc )
239234
240- self .acc_tracking = QueueDataAccumulator (
241- name = "tracking" ,
242- experiment = self ,
243- data_queue = self .tracking_output_queue ,
244- monitored_headers = self .pipeline .headers_to_plot ,
245- )
246- self .acc_tracking .sig_acc_init .connect (self .refresh_plots )
247-
248- # Data accumulator is updated with GUI timer:
249- self .gui_timer .timeout .connect (self .acc_tracking .update_list )
250-
251- # Tracking is reset at experiment start:
252- self .protocol_runner .sig_protocol_started .connect (self .acc_tracking .reset )
253-
254235 # start frame dispatcher process:
255236 self .frame_dispatcher .start ()
256237
@@ -263,15 +244,28 @@ def __init__(
263244 est = est_type
264245
265246 if est is not None :
247+ self .estimator_process = EstimatorProcess (est_type , self .tracking_output_queue , self .finished_sig )
266248 self .estimator_log = EstimatorLog (experiment = self )
267- self .estimator = est (
268- self .acc_tracking ,
269- experiment = self ,
270- ** tracking .get ("estimator_params" , {})
271- )
249+ self .estimator = est (self .acc_tracking , experiment = self , ** tracking .get ("estimator_params" , {}))
272250 self .estimator_log .sig_acc_init .connect (self .refresh_plots )
251+ tracking_output_queue = self .estimator_process .tracking_output_queue
273252 else :
274253 self .estimator = None
254+ tracking_output_queue = self .tracking_output_queue
255+
256+ self .acc_tracking = QueueDataAccumulator (
257+ name = "tracking" ,
258+ experiment = self ,
259+ data_queue = tracking_output_queue ,
260+ monitored_headers = self .pipeline .headers_to_plot ,
261+ )
262+ self .acc_tracking .sig_acc_init .connect (self .refresh_plots )
263+
264+ # Data accumulator is updated with GUI timer:
265+ self .gui_timer .timeout .connect (self .acc_tracking .update_list )
266+
267+ # Tracking is reset at experiment start:
268+ self .protocol_runner .sig_protocol_started .connect (self .acc_tracking .reset )
275269
276270 self .acc_tracking_framerate = FramerateQueueAccumulator (
277271 self ,
@@ -376,9 +370,7 @@ def end_protocol(self, save=True):
376370 def save_data (self ):
377371 """Save tail position and dynamic parameters and terminate."""
378372
379- self .window_main .camera_display .save_image (
380- name = self .filename_base () + "img.png"
381- )
373+ self .window_main .camera_display .save_image (name = self .filename_base () + "img.png" )
382374 self .dc .add_static_data (self .filename_prefix () + "img.png" , "tracking/image" )
383375
384376 # Save log and estimators:
0 commit comments