@@ -174,7 +174,7 @@ def compute(self, **kwargs):
174174 self .criteria ['_task_passed_trial_checks' ] = {'NOT_SET' : 0 }
175175
176176 self .log .info (f'Session { self .session_path } : Running QC on behavior data...' )
177- self .metrics , self . passed = get_bpodqc_metrics_frame (
177+ self .get_bpodqc_metrics_frame (
178178 self .extractor .data ,
179179 wheel_gain = self .extractor .settings ['STIM_GAIN' ], # The wheel gain
180180 photodiode = self .extractor .frame_ttls ,
@@ -183,7 +183,56 @@ def compute(self, **kwargs):
183183 min_qt = self .extractor .settings .get ('QUIESCENT_PERIOD' ) or 0.2 ,
184184 audio_output = self .extractor .settings .get ('device_sound' , {}).get ('OUTPUT' , 'unknown' )
185185 )
186- return
186+
187+ def _get_checks (self ):
188+ """
189+ Find all methods that begin with 'check_'.
190+
191+ Returns
192+ -------
193+ Dict[str, function]
194+ A map of QC check function names and the corresponding functions that return `metric`
195+ (any), `passed` (bool).
196+ """
197+ def is_metric (x ):
198+ return isfunction (x ) and x .__name__ .startswith ('check_' )
199+
200+ return dict (getmembers (sys .modules [__name__ ], is_metric ))
201+
202+ def get_bpodqc_metrics_frame (self , data , ** kwargs ):
203+ """
204+ Evaluates all the QC metric functions in this module (those starting with 'check') and
205+ returns the results. The optional kwargs listed below are passed to each QC metric function.
206+ :param data: dict of extracted task data
207+ :param re_encoding: the encoding of the wheel data, X1, X2 or X4
208+ :param enc_res: the rotary encoder resolution
209+ :param wheel_gain: the STIM_GAIN task parameter
210+ :param photodiode: the fronts from Bpod's BNC1 input or FPGA frame2ttl channel
211+ :param audio: the fronts from Bpod's BNC2 input FPGA audio sync channel
212+ :param min_qt: the QUIESCENT_PERIOD task parameter
213+ :return metrics: dict of checks and their QC metrics
214+ :return passed: dict of checks and a float array of which samples passed
215+ """
216+
217+ # Find all methods that begin with 'check_'
218+ checks = self ._get_checks ()
219+ prefix = '_task_' # Extended QC fields will start with this
220+ # Method 'check_foobar' stored with key '_task_foobar' in metrics map
221+ qc_metrics_map = {prefix + k [6 :]: fn (data , ** kwargs ) for k , fn in checks .items ()}
222+
223+ # Split metrics and passed frames
224+ self .metrics = {}
225+ self .passed = {}
226+ for k in qc_metrics_map :
227+ self .metrics [k ], self .passed [k ] = qc_metrics_map [k ]
228+
229+ # Add a check for trial level pass: did a given trial pass all checks?
230+ n_trials = data ['intervals' ].shape [0 ]
231+ # Trial-level checks return an array the length that equals the number of trials
232+ trial_level_passed = [m for m in self .passed .values () if isinstance (m , Sized ) and len (m ) == n_trials ]
233+ name = prefix + 'passed_trial_checks'
234+ self .metrics [name ] = reduce (np .logical_and , trial_level_passed or (None , None ))
235+ self .passed [name ] = self .metrics [name ].astype (float ) if trial_level_passed else None
187236
188237 def run (self , update = False , namespace = 'task' , ** kwargs ):
189238 """
@@ -377,46 +426,6 @@ def compute(self, download_data=None, **kwargs):
377426 self .metrics , self .passed = (metrics , passed )
378427
379428
380- def get_bpodqc_metrics_frame (data , ** kwargs ):
381- """
382- Evaluates all the QC metric functions in this module (those starting with 'check') and
383- returns the results. The optional kwargs listed below are passed to each QC metric function.
384- :param data: dict of extracted task data
385- :param re_encoding: the encoding of the wheel data, X1, X2 or X4
386- :param enc_res: the rotary encoder resolution
387- :param wheel_gain: the STIM_GAIN task parameter
388- :param photodiode: the fronts from Bpod's BNC1 input or FPGA frame2ttl channel
389- :param audio: the fronts from Bpod's BNC2 input FPGA audio sync channel
390- :param min_qt: the QUIESCENT_PERIOD task parameter
391- :return metrics: dict of checks and their QC metrics
392- :return passed: dict of checks and a float array of which samples passed
393- """
394- def is_metric (x ):
395- return isfunction (x ) and x .__name__ .startswith ('check_' )
396- # Find all methods that begin with 'check_'
397- checks = getmembers (sys .modules [__name__ ], is_metric )
398- prefix = '_task_' # Extended QC fields will start with this
399- # Method 'check_foobar' stored with key '_task_foobar' in metrics map
400- qc_metrics_map = {prefix + k [6 :]: fn (data , ** kwargs ) for k , fn in checks }
401-
402- # Split metrics and passed frames
403- metrics = {}
404- passed = {}
405- for k in qc_metrics_map :
406- metrics [k ], passed [k ] = qc_metrics_map [k ]
407-
408- # Add a check for trial level pass: did a given trial pass all checks?
409- n_trials = data ['intervals' ].shape [0 ]
410- # Trial-level checks return an array the length that equals the number of trials
411- trial_level_passed = [m for m in passed .values ()
412- if isinstance (m , Sized ) and len (m ) == n_trials ]
413- name = prefix + 'passed_trial_checks'
414- metrics [name ] = reduce (np .logical_and , trial_level_passed or (None , None ))
415- passed [name ] = metrics [name ].astype (float ) if trial_level_passed else None
416-
417- return metrics , passed
418-
419-
420429# SINGLE METRICS
421430# ---------------------------------------------------------------------------- #
422431
0 commit comments