Skip to content

Commit 956b1ed

Browse files
committed
_get_checks method of TaskQC for overloading in subclass;
get_bpodqc_metrics_frame moved from function to TaskQC method; QC kwarg for behaviour task run_qc method allows passing of custom TaskQC class; deprecated videopc params functions in ibllib.pipes.misc; raise in BehaviourTask.get_protocol when task protocol ambiguous
1 parent 043441f commit 956b1ed

File tree

8 files changed

+231
-65
lines changed

8 files changed

+231
-65
lines changed

ibllib/io/extractors/base.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,16 @@
1717

1818
class BaseExtractor(abc.ABC):
1919
"""
20-
Base extractor class
20+
Base extractor class.
21+
2122
Writing an extractor checklist:
22-
- on the child class, overload the _extract method
23-
- this method should output one or several numpy.arrays or dataframe with a consistent shape
24-
- save_names is a list or a string of filenames, there should be one per dataset
25-
- set save_names to None for a dataset that doesn't need saving (could be set dynamically
26-
in the _extract method)
23+
24+
- on the child class, overload the _extract method
25+
- this method should output one or several numpy.arrays or dataframe with a consistent shape
26+
- save_names is a list or a string of filenames, there should be one per dataset
27+
- set save_names to None for a dataset that doesn't need saving (could be set dynamically in
28+
the _extract method)
29+
2730
:param session_path: Absolute path of session folder
2831
:type session_path: str/Path
2932
"""
@@ -122,10 +125,11 @@ def _extract(self):
122125

123126
class BaseBpodTrialsExtractor(BaseExtractor):
124127
"""
125-
Base (abstract) extractor class for bpod jsonable data set
126-
Wrps the _extract private method
128+
Base (abstract) extractor class for bpod jsonable data set.
127129
128-
:param session_path: Absolute path of session folder
130+
Wraps the _extract private method.
131+
132+
:param session_path: Absolute path of session folder.
129133
:type session_path: str
130134
:param bpod_trials
131135
:param settings
@@ -159,6 +163,12 @@ def extract(self, bpod_trials=None, settings=None, **kwargs):
159163
self.settings["IBLRIG_VERSION"] = "100.0.0"
160164
return super(BaseBpodTrialsExtractor, self).extract(**kwargs)
161165

166+
@property
167+
def alf_path(self):
168+
"""pathlib.Path: The full task collection filepath."""
169+
if self.session_path:
170+
return self.session_path.joinpath(self.task_collection or '').absolute()
171+
162172

163173
def run_extractor_classes(classes, session_path=None, **kwargs):
164174
"""

ibllib/io/extractors/training_trials.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -570,9 +570,9 @@ def get_stimOn_times_ge5(session_path, data=False, task_collection='raw_behavior
570570
@staticmethod
571571
def get_stimOn_times_lt5(session_path, data=False, task_collection='raw_behavior_data'):
572572
"""
573-
Find the time of the statemachine command to turn on hte stim
573+
Find the time of the statemachine command to turn on the stim
574574
(state stim_on start or rotary_encoder_event2)
575-
Find the next frame change from the photodiodeafter that TS.
575+
Find the next frame change from the photodiode after that TS.
576576
Screen is not displaying anything until then.
577577
(Frame changes are in BNC1High and BNC1Low)
578578
"""

ibllib/pipes/base_tasks.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,98 @@ def __init__(self, session_path, **kwargs):
9090
self.output_collection += f'/task_{self.protocol_number:02}'
9191

9292
def get_protocol(self, protocol=None, task_collection=None):
93-
return protocol if protocol else sess_params.get_task_protocol(self.session_params, task_collection)
93+
"""
94+
Return the task protocol name.
95+
96+
This returns the task protocol based on the task collection. If `protocol` is not None, this
97+
acts as an identity function. If both `task_collection` and `protocol` are None, returns
98+
the protocol defined in the experiment description file only if a single protocol was run.
99+
If the `task_collection` is not None, the associated protocol name is returned.
100+
101+
102+
Parameters
103+
----------
104+
protocol : str
105+
A task protocol name. If not None, the same value is returned.
106+
task_collection : str
107+
The task collection whose protocol name to return. May be None if only one protocol run.
108+
109+
Returns
110+
-------
111+
str, None
112+
The task protocol name, or None, if no protocol found.
113+
114+
Raises
115+
------
116+
ValueError
117+
For session with multiple task protocols, a task collection must be passed.
118+
"""
119+
if protocol:
120+
return protocol
121+
protocol = sess_params.get_task_protocol(self.session_params, task_collection) or None
122+
if isinstance(protocol, set):
123+
if len(protocol) == 1:
124+
protocol = next(iter(protocol))
125+
else:
126+
raise ValueError('Multiple task protocols for session. Task collection must be explicitly defined.')
127+
return protocol
94128

95129
def get_task_collection(self, collection=None):
130+
"""
131+
Return the task collection.
132+
133+
If `collection` is not None, this acts as an identity function. Otherwise loads it from
134+
the experiment description if only one protocol was run.
135+
136+
Parameters
137+
----------
138+
collection : str
139+
A task collection. If not None, the same value is returned.
140+
141+
Returns
142+
-------
143+
str, None
144+
The task collection, or None if no task protocols were run.
145+
146+
Raises
147+
------
148+
AssertionError
149+
Raised if multiple protocols were run and collection is None, or if experiment
150+
description file is improperly formatted.
151+
152+
"""
96153
if not collection:
97154
collection = sess_params.get_task_collection(self.session_params)
98155
# If inferring the collection from the experiment description, assert only one returned
99156
assert collection is None or isinstance(collection, str) or len(collection) == 1
100157
return collection
101158

102159
def get_protocol_number(self, number=None, task_protocol=None):
160+
"""
161+
Return the task protocol number.
162+
163+
Numbering starts from 0. If the 'protocol_number' field is missing from the experiment
164+
description, None is returned. If `task_protocol` is None, the first protocol number if n
165+
protocols == 1, otherwise returns None.
166+
167+
NB: :func:`ibllib.pipes.dynamic_pipeline.make_pipeline` will determine the protocol number
168+
from the order of the tasks in the experiment description if the task collection follows
169+
the pattern 'raw_task_data_XX'. If the task protocol does not follow this pattern, the
170+
experiment description file should explicitly define the number with the 'protocol_number'
171+
field.
172+
173+
Parameters
174+
----------
175+
number : int
176+
The protocol number. If not None, the same value is returned.
177+
task_protocol : str
178+
The task protocol name.
179+
180+
Returns
181+
-------
182+
int, None
183+
The task protocol number, if defined.
184+
"""
103185
if number is None: # Do not use "if not number" as that will return True if number is 0
104186
number = sess_params.get_task_protocol_number(self.session_params, task_protocol)
105187
# If inferring the number from the experiment description, assert only one returned (or something went wrong)

ibllib/pipes/behavior_tasks.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ class ChoiceWorldTrialsBpod(base_tasks.BehaviourTask):
277277
priority = 90
278278
job_size = 'small'
279279
extractor = None
280+
"""ibllib.io.extractors.base.BaseBpodTrialsExtractor: An instance of the Bpod trials extractor."""
280281

281282
@property
282283
def signature(self):
@@ -318,7 +319,24 @@ def _extract_behaviour(self, **kwargs):
318319
self.extractor.default_path = self.output_collection
319320
return self.extractor.extract(task_collection=self.collection, **kwargs)
320321

321-
def _run_qc(self, trials_data=None, update=True):
322+
def _run_qc(self, trials_data=None, update=True, QC=None):
323+
"""
324+
Run the task QC.
325+
326+
Parameters
327+
----------
328+
trials_data : dict
329+
The complete extracted task data.
330+
update : bool
331+
If True, updates the session QC fields on Alyx.
332+
QC : ibllib.qc.task_metrics.TaskQC
333+
An optional QC class to instantiate.
334+
335+
Returns
336+
-------
337+
ibllib.qc.task_metrics.TaskQC
338+
The task QC object.
339+
"""
322340
if not self.extractor or trials_data is None:
323341
trials_data, _ = self._extract_behaviour(save=False)
324342
if not trials_data:
@@ -328,10 +346,11 @@ def _run_qc(self, trials_data=None, update=True):
328346
qc_extractor = TaskQCExtractor(self.session_path, lazy=True, sync_collection=self.sync_collection, one=self.one,
329347
sync_type=self.sync, task_collection=self.collection)
330348
qc_extractor.data = qc_extractor.rename_data(trials_data)
331-
if type(self.extractor).__name__ == 'HabituationTrials':
332-
qc = HabituationQC(self.session_path, one=self.one, log=_logger)
333-
else:
334-
qc = TaskQC(self.session_path, one=self.one, log=_logger)
349+
if not QC:
350+
QC = HabituationQC if type(self.extractor).__name__ == 'HabituationTrials' else TaskQC
351+
_logger.debug('Running QC with %s.%s', QC.__module__, QC.__name__)
352+
qc = QC(self.session_path, one=self.one, log=_logger)
353+
if QC is not HabituationQC:
335354
qc_extractor.wheel_encoding = 'X1'
336355
qc_extractor.settings = self.extractor.settings
337356
qc_extractor.frame_ttls, qc_extractor.audio_ttls = load_bpod_fronts(
@@ -412,7 +431,7 @@ def _extract_behaviour(self, save=True, **kwargs):
412431
task_collection=self.collection, protocol_number=self.protocol_number, **kwargs)
413432
return outputs, files
414433

415-
def _run_qc(self, trials_data=None, update=False, plot_qc=False):
434+
def _run_qc(self, trials_data=None, update=False, plot_qc=False, QC=None):
416435
if not self.extractor or trials_data is None:
417436
trials_data, _ = self._extract_behaviour(save=False)
418437
if not trials_data:
@@ -422,10 +441,11 @@ def _run_qc(self, trials_data=None, update=False, plot_qc=False):
422441
qc_extractor = TaskQCExtractor(self.session_path, lazy=True, sync_collection=self.sync_collection, one=self.one,
423442
sync_type=self.sync, task_collection=self.collection)
424443
qc_extractor.data = qc_extractor.rename_data(trials_data.copy())
425-
if type(self.extractor).__name__ == 'HabituationTrials':
426-
qc = HabituationQC(self.session_path, one=self.one, log=_logger)
427-
else:
428-
qc = TaskQC(self.session_path, one=self.one, log=_logger)
444+
if not QC:
445+
QC = HabituationQC if type(self.extractor).__name__ == 'HabituationTrials' else TaskQC
446+
_logger.debug('Running QC with %s.%s', QC.__module__, QC.__name__)
447+
qc = QC(self.session_path, one=self.one, log=_logger)
448+
if QC is not HabituationQC:
429449
# Add Bpod wheel data
430450
wheel_ts_bpod = self.extractor.bpod2fpga(self.extractor.bpod_trials['wheel_timestamps'])
431451
qc_extractor.data['wheel_timestamps_bpod'] = wheel_ts_bpod

ibllib/pipes/misc.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import sys
1010
import time
1111
import logging
12+
import warnings
1213
from functools import wraps
1314
from pathlib import Path
1415
from typing import Union, List, Callable, Any
@@ -365,6 +366,8 @@ def load_params_dict(params_fname: str) -> dict:
365366

366367

367368
def load_videopc_params():
369+
"""(DEPRECATED) This will be removed in favour of iblrigv8 functions."""
370+
warnings.warn('load_videopc_params will be removed in favour of iblrigv8', FutureWarning)
368371
if not load_params_dict("videopc_params"):
369372
create_videopc_params()
370373
return load_params_dict("videopc_params")
@@ -472,6 +475,9 @@ def create_basic_transfer_params(param_str='transfer_params', local_data_path=No
472475

473476

474477
def create_videopc_params(force=False, silent=False):
478+
"""(DEPRECATED) This will be removed in favour of iblrigv8 functions."""
479+
url = 'https://github.com/int-brain-lab/iblrig/blob/videopc/docs/source/video.rst'
480+
warnings.warn(f'create_videopc_params is deprecated, see {url}', DeprecationWarning)
475481
if Path(params.getfile("videopc_params")).exists() and not force:
476482
print(f"{params.getfile('videopc_params')} exists already, exiting...")
477483
print(Path(params.getfile("videopc_params")).exists())

ibllib/pipes/tasks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def assert_expected(self, expected_files, silent=False):
385385
everything_is_fine = True
386386
files = []
387387
for expected_file in expected_files:
388-
actual_files = list(Path(self.session_path).rglob(str(Path(expected_file[1]).joinpath(expected_file[0]))))
388+
actual_files = list(Path(self.session_path).rglob(str(Path(*filter(None, reversed(expected_file[:2]))))))
389389
if len(actual_files) == 0 and expected_file[2]:
390390
everything_is_fine = False
391391
if not silent:

ibllib/qc/task_metrics.py

Lines changed: 51 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def compute(self, **kwargs):
174174
self.criteria['_task_passed_trial_checks'] = {'NOT_SET': 0}
175175

176176
self.log.info(f'Session {self.session_path}: Running QC on behavior data...')
177-
self.metrics, self.passed = get_bpodqc_metrics_frame(
177+
self.get_bpodqc_metrics_frame(
178178
self.extractor.data,
179179
wheel_gain=self.extractor.settings['STIM_GAIN'], # The wheel gain
180180
photodiode=self.extractor.frame_ttls,
@@ -183,7 +183,56 @@ def compute(self, **kwargs):
183183
min_qt=self.extractor.settings.get('QUIESCENT_PERIOD') or 0.2,
184184
audio_output=self.extractor.settings.get('device_sound', {}).get('OUTPUT', 'unknown')
185185
)
186-
return
186+
187+
def _get_checks(self):
188+
"""
189+
Find all methods that begin with 'check_'.
190+
191+
Returns
192+
-------
193+
Dict[str, function]
194+
A map of QC check function names and the corresponding functions that return `metric`
195+
(any), `passed` (bool).
196+
"""
197+
def is_metric(x):
198+
return isfunction(x) and x.__name__.startswith('check_')
199+
200+
return dict(getmembers(sys.modules[__name__], is_metric))
201+
202+
def get_bpodqc_metrics_frame(self, data, **kwargs):
203+
"""
204+
Evaluates all the QC metric functions in this module (those starting with 'check') and
205+
returns the results. The optional kwargs listed below are passed to each QC metric function.
206+
:param data: dict of extracted task data
207+
:param re_encoding: the encoding of the wheel data, X1, X2 or X4
208+
:param enc_res: the rotary encoder resolution
209+
:param wheel_gain: the STIM_GAIN task parameter
210+
:param photodiode: the fronts from Bpod's BNC1 input or FPGA frame2ttl channel
211+
:param audio: the fronts from Bpod's BNC2 input FPGA audio sync channel
212+
:param min_qt: the QUIESCENT_PERIOD task parameter
213+
:return metrics: dict of checks and their QC metrics
214+
:return passed: dict of checks and a float array of which samples passed
215+
"""
216+
217+
# Find all methods that begin with 'check_'
218+
checks = self._get_checks()
219+
prefix = '_task_' # Extended QC fields will start with this
220+
# Method 'check_foobar' stored with key '_task_foobar' in metrics map
221+
qc_metrics_map = {prefix + k[6:]: fn(data, **kwargs) for k, fn in checks.items()}
222+
223+
# Split metrics and passed frames
224+
self.metrics = {}
225+
self.passed = {}
226+
for k in qc_metrics_map:
227+
self.metrics[k], self.passed[k] = qc_metrics_map[k]
228+
229+
# Add a check for trial level pass: did a given trial pass all checks?
230+
n_trials = data['intervals'].shape[0]
231+
# Trial-level checks return an array the length that equals the number of trials
232+
trial_level_passed = [m for m in self.passed.values() if isinstance(m, Sized) and len(m) == n_trials]
233+
name = prefix + 'passed_trial_checks'
234+
self.metrics[name] = reduce(np.logical_and, trial_level_passed or (None, None))
235+
self.passed[name] = self.metrics[name].astype(float) if trial_level_passed else None
187236

188237
def run(self, update=False, namespace='task', **kwargs):
189238
"""
@@ -377,46 +426,6 @@ def compute(self, download_data=None, **kwargs):
377426
self.metrics, self.passed = (metrics, passed)
378427

379428

380-
def get_bpodqc_metrics_frame(data, **kwargs):
381-
"""
382-
Evaluates all the QC metric functions in this module (those starting with 'check') and
383-
returns the results. The optional kwargs listed below are passed to each QC metric function.
384-
:param data: dict of extracted task data
385-
:param re_encoding: the encoding of the wheel data, X1, X2 or X4
386-
:param enc_res: the rotary encoder resolution
387-
:param wheel_gain: the STIM_GAIN task parameter
388-
:param photodiode: the fronts from Bpod's BNC1 input or FPGA frame2ttl channel
389-
:param audio: the fronts from Bpod's BNC2 input FPGA audio sync channel
390-
:param min_qt: the QUIESCENT_PERIOD task parameter
391-
:return metrics: dict of checks and their QC metrics
392-
:return passed: dict of checks and a float array of which samples passed
393-
"""
394-
def is_metric(x):
395-
return isfunction(x) and x.__name__.startswith('check_')
396-
# Find all methods that begin with 'check_'
397-
checks = getmembers(sys.modules[__name__], is_metric)
398-
prefix = '_task_' # Extended QC fields will start with this
399-
# Method 'check_foobar' stored with key '_task_foobar' in metrics map
400-
qc_metrics_map = {prefix + k[6:]: fn(data, **kwargs) for k, fn in checks}
401-
402-
# Split metrics and passed frames
403-
metrics = {}
404-
passed = {}
405-
for k in qc_metrics_map:
406-
metrics[k], passed[k] = qc_metrics_map[k]
407-
408-
# Add a check for trial level pass: did a given trial pass all checks?
409-
n_trials = data['intervals'].shape[0]
410-
# Trial-level checks return an array the length that equals the number of trials
411-
trial_level_passed = [m for m in passed.values()
412-
if isinstance(m, Sized) and len(m) == n_trials]
413-
name = prefix + 'passed_trial_checks'
414-
metrics[name] = reduce(np.logical_and, trial_level_passed or (None, None))
415-
passed[name] = metrics[name].astype(float) if trial_level_passed else None
416-
417-
return metrics, passed
418-
419-
420429
# SINGLE METRICS
421430
# ---------------------------------------------------------------------------- #
422431

0 commit comments

Comments
 (0)