Skip to content

Commit 908f530

Browse files
authored
Merge pull request #741 from int-brain-lab/taskRegistration
Task registration
2 parents f4c02c8 + de2bf8e commit 908f530

18 files changed

+678
-391
lines changed

ibllib/io/extractors/biased_trials.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717

1818

1919
class ContrastLR(BaseBpodTrialsExtractor):
20-
"""
21-
Get left and right contrasts from raw datafile.
22-
"""
20+
"""Get left and right contrasts from raw datafile."""
2321
save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy')
2422
var_names = ('contrastLeft', 'contrastRight')
2523

@@ -32,9 +30,7 @@ def _extract(self, **kwargs):
3230

3331

3432
class ProbaContrasts(BaseBpodTrialsExtractor):
35-
"""
36-
Bpod pre-generated values for probabilityLeft, contrastLR, phase, quiescence
37-
"""
33+
"""Bpod pre-generated values for probabilityLeft, contrastLR, phase, quiescence."""
3834
save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy', None, None,
3935
'_ibl_trials.probabilityLeft.npy', '_ibl_trials.quiescencePeriod.npy')
4036
var_names = ('contrastLeft', 'contrastRight', 'phase',
@@ -103,10 +99,12 @@ class TrialsTableBiased(BaseBpodTrialsExtractor):
10399
'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement')
104100

105101
def _extract(self, extractor_classes=None, **kwargs):
102+
extractor_classes = extractor_classes or []
106103
base = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR, FeedbackTimes, FeedbackType,
107104
RewardVolume, ProbabilityLeft, Wheel]
108-
out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
109-
save=False, task_collection=self.task_collection)
105+
out, _ = run_extractor_classes(
106+
base + extractor_classes, session_path=self.session_path, bpod_trials=self.bpod_trials,
107+
settings=self.settings, save=False, task_collection=self.task_collection)
110108

111109
table = AlfBunch({k: out.pop(k) for k in list(out.keys()) if k not in self.var_names})
112110
assert len(table.keys()) == 12
@@ -130,11 +128,13 @@ class TrialsTableEphys(BaseBpodTrialsExtractor):
130128
'phase', 'position', 'quiescence')
131129

132130
def _extract(self, extractor_classes=None, **kwargs):
131+
extractor_classes = extractor_classes or []
133132
base = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ProbaContrasts,
134133
FeedbackTimes, FeedbackType, RewardVolume, Wheel]
135134
# Exclude from trials table
136-
out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
137-
save=False, task_collection=self.task_collection)
135+
out, _ = run_extractor_classes(
136+
base + extractor_classes, session_path=self.session_path, bpod_trials=self.bpod_trials,
137+
settings=self.settings, save=False, task_collection=self.task_collection)
138138
table = AlfBunch({k: v for k, v in out.items() if k not in self.var_names})
139139
assert len(table.keys()) == 12
140140

@@ -158,11 +158,13 @@ class BiasedTrials(BaseBpodTrialsExtractor):
158158
'phase', 'position', 'quiescence')
159159

160160
def _extract(self, extractor_classes=None, **kwargs) -> dict:
161+
extractor_classes = extractor_classes or []
161162
base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes,
162163
ErrorCueTriggerTimes, TrialsTableBiased, IncludedTrials, PhasePosQuiescence]
163164
# Exclude from trials table
164-
out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
165-
save=False, task_collection=self.task_collection)
165+
out, _ = run_extractor_classes(
166+
base + extractor_classes, session_path=self.session_path, bpod_trials=self.bpod_trials,
167+
settings=self.settings, save=False, task_collection=self.task_collection)
166168
return {k: out[k] for k in self.var_names}
167169

168170

@@ -181,13 +183,15 @@ class EphysTrials(BaseBpodTrialsExtractor):
181183
'phase', 'position', 'quiescence')
182184

183185
def _extract(self, extractor_classes=None, **kwargs) -> dict:
186+
extractor_classes = extractor_classes or []
184187
base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes,
185188
ErrorCueTriggerTimes, TrialsTableEphys, IncludedTrials, PhasePosQuiescence]
186189
# Get all detected TTLs. These are stored for QC purposes
187190
self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials)
188191
# Exclude from trials table
189-
out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings,
190-
save=False, task_collection=self.task_collection)
192+
out, _ = run_extractor_classes(
193+
base + extractor_classes, session_path=self.session_path, bpod_trials=self.bpod_trials,
194+
settings=self.settings, save=False, task_collection=self.task_collection)
191195
return {k: out[k] for k in self.var_names}
192196

193197

ibllib/io/extractors/extractor_types.json

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
{"ksocha_ephysOptoStimulation": "ephys_passive_opto",
1+
{"THIS FILE": "SHOULD NO LONGER BE USED!",
2+
"SEE": "https://github.com/int-brain-lab/project_extraction?tab=readme-ov-file#project_extraction",
3+
"********": "*******************************",
4+
"ksocha_ephysOptoStimulation": "ephys_passive_opto",
25
"ksocha_ephysOptoChoiceWorld": "ephys_biased_opto",
36
"passiveChoiceWorld": "ephys_replay",
47
"opto_ephysChoiceWorld": "ephys_biased_opto",
@@ -17,5 +20,7 @@
1720
"_habituationChoiceWorld": "habituation",
1821
"_trainingChoiceWorld": "training",
1922
"ephysMockChoiceWorld": "mock_ephys",
20-
"ephys_certification": "sync_ephys"
23+
"ephys_certification": "sync_ephys",
24+
"trainingPhaseChoiceWorld": "training",
25+
"************": "*********************"
2126
}

ibllib/io/extractors/task_extractor_map.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
{"ephysChoiceWorld": "EphysTrials",
1+
{"!!THIS FILE": "SHOULD NOT BE EDITED...",
2+
"SEE": "PROJECT EXTRACTION REPO!!",
3+
"************": "**********************",
4+
"ephysChoiceWorld": "EphysTrials",
25
"_biasedChoiceWorld": "BiasedTrials",
36
"_habituationChoiceWorld": "HabituationTrials",
47
"_trainingChoiceWorld": "TrainingTrials",

ibllib/oneibl/data_handlers.py

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,14 @@ def __init__(self, session_path, signature, one=None):
3434
self.session_path = session_path
3535
self.signature = signature
3636
self.one = one
37+
self.processed = {} # Map of filepaths and their processed records (e.g. upload receipts or Alyx records)
3738

3839
def setUp(self):
3940
"""Function to optionally overload to download required data to run task."""
4041
pass
4142

4243
def getData(self, one=None):
43-
"""
44-
Finds the datasets required for task based on input signatures
45-
:return:
46-
"""
44+
"""Finds the datasets required for task based on input signatures."""
4745
if self.one is None and one is None:
4846
return
4947

@@ -60,6 +58,22 @@ def getData(self, one=None):
6058
df = df.droplevel(level='eid')
6159
return df
6260

61+
def getOutputFiles(self):
62+
assert self.session_path
63+
from one.alf.io import iter_datasets
64+
# Next convert datasets to frame
65+
from one.alf.cache import DATASETS_COLUMNS, _get_dataset_info
66+
# Create dataframe of all ALF datasets
67+
dsets = iter_datasets(self.session_path)
68+
records = [_get_dataset_info(self.session_path, dset, compute_hash=False) for dset in dsets]
69+
df = pd.DataFrame(records, columns=DATASETS_COLUMNS)
70+
from functools import partial
71+
filt = partial(filter_datasets, df, wildcards=True, assert_unique=False)
72+
# Filter outputs
73+
dids = pd.concat(filt(filename=file[0], collection=file[1]).index for file in self.signature['output_files'])
74+
present = df.loc[dids, :].copy()
75+
return present
76+
6377
def uploadData(self, outputs, version):
6478
"""
6579
Function to optionally overload to upload and register data
@@ -75,10 +89,7 @@ def uploadData(self, outputs, version):
7589
return versions
7690

7791
def cleanUp(self):
78-
"""
79-
Function to optionally overload to cleanup files after running task
80-
:return:
81-
"""
92+
"""Function to optionally overload to clean up files after running task."""
8293
pass
8394

8495

@@ -104,16 +115,47 @@ def __init__(self, session_path, signatures, one=None):
104115
"""
105116
super().__init__(session_path, signatures, one=one)
106117

107-
def uploadData(self, outputs, version, **kwargs):
118+
def uploadData(self, outputs, version, clobber=False, **kwargs):
108119
"""
109-
Function to upload and register data of completed task
110-
:param outputs: output files from task to register
111-
:param version: ibllib version
112-
:return: output info of registered datasets
120+
Upload and/or register output data.
121+
122+
This is typically called by :meth:`ibllib.pipes.tasks.Task.register_datasets`.
123+
124+
Parameters
125+
----------
126+
outputs : list of pathlib.Path
127+
A set of ALF paths to register to Alyx.
128+
version : str, list of str
129+
The version of ibllib used to generate these output files.
130+
clobber : bool
131+
If True, re-upload outputs that have already been passed to this method.
132+
kwargs
133+
Optional keyword arguments for one.registration.RegistrationClient.register_files.
134+
135+
Returns
136+
-------
137+
list of dicts, dict
138+
A list of newly created Alyx dataset records or the registration data if dry.
113139
"""
114140
versions = super().uploadData(outputs, version)
115141
data_repo = get_local_data_repository(self.one.alyx)
116-
return register_dataset(outputs, one=self.one, versions=versions, repository=data_repo, **kwargs)
142+
# If clobber = False, do not re-upload the outputs that have already been processed
143+
if not isinstance(outputs, list):
144+
outputs = [outputs]
145+
to_upload = list(filter(None if clobber else lambda x: x not in self.processed, outputs))
146+
records = register_dataset(to_upload, one=self.one, versions=versions, repository=data_repo, **kwargs) or []
147+
if kwargs.get('dry', False):
148+
return records
149+
# Store processed outputs
150+
self.processed.update({k: v for k, v in zip(to_upload, records) if v})
151+
return [self.processed[x] for x in outputs if x in self.processed]
152+
153+
def cleanUp(self):
154+
"""Empties and returns the processed dataset mep."""
155+
super().cleanUp()
156+
processed = self.processed
157+
self.processed = {}
158+
return processed
117159

118160

119161
class ServerGlobusDataHandler(DataHandler):

ibllib/pipes/behavior_tasks.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
from one.api import ONE
99

1010
from ibllib.oneibl.registration import get_lab
11+
from ibllib.oneibl.data_handlers import ServerDataHandler
1112
from ibllib.pipes import base_tasks
1213
from ibllib.io.raw_data_loaders import load_settings, load_bpod_fronts
1314
from ibllib.qc.task_extractors import TaskQCExtractor
14-
from ibllib.qc.task_metrics import HabituationQC, TaskQC
15+
from ibllib.qc.task_metrics import HabituationQC, TaskQC, update_dataset_qc
1516
from ibllib.io.extractors.ephys_passive import PassiveChoiceWorld
1617
from ibllib.io.extractors.bpod_trials import get_bpod_extractor
1718
from ibllib.io.extractors.ephys_fpga import FpgaTrials, FpgaTrialsHabituation, get_sync_and_chn_map
@@ -72,9 +73,7 @@ def signature(self):
7273
return signature
7374

7475
def _run(self, update=True, save=True):
75-
"""
76-
Extracts an iblrig training session
77-
"""
76+
"""Extracts an iblrig training session."""
7877
trials, output_files = self.extract_behaviour(save=save)
7978

8079
if trials is None:
@@ -296,7 +295,7 @@ def signature(self):
296295
}
297296
return signature
298297

299-
def _run(self, update=True, save=True):
298+
def _run(self, update=True, save=True, **kwargs):
300299
"""Extracts an iblrig training session."""
301300
trials, output_files = self.extract_behaviour(save=save)
302301
if trials is None:
@@ -305,7 +304,16 @@ def _run(self, update=True, save=True):
305304
return output_files
306305

307306
# Run the task QC
308-
self.run_qc(trials)
307+
qc = self.run_qc(trials, update=update, **kwargs)
308+
if update and not self.one.offline:
309+
on_server = self.location == 'server' and isinstance(self.data_handler, ServerDataHandler)
310+
if not on_server:
311+
_logger.warning('Updating dataset QC only supported on local servers')
312+
else:
313+
labs = get_lab(self.session_path, self.one.alyx)
314+
# registered_dsets = self.register_datasets(labs=labs)
315+
datasets = self.data_handler.uploadData(output_files, self.version, labs=labs)
316+
update_dataset_qc(qc, datasets, self.one)
309317

310318
return output_files
311319

@@ -467,14 +475,11 @@ def run_qc(self, trials_data=None, update=False, plot_qc=False, QC=None):
467475
return qc
468476

469477
def _run(self, update=True, plot_qc=True, save=True):
470-
dsets, out_files = self.extract_behaviour(save=save)
471-
472-
if not self.one or self.one.offline:
473-
return out_files
478+
output_files = super()._run(update=update, save=save, plot_qc=plot_qc)
479+
if update and not self.one.offline:
480+
self._behaviour_criterion(update=update)
474481

475-
self._behaviour_criterion(update=update)
476-
self.run_qc(dsets, update=update, plot_qc=plot_qc)
477-
return out_files
482+
return output_files
478483

479484

480485
class ChoiceWorldTrialsTimeline(ChoiceWorldTrialsNidq):

ibllib/pipes/dynamic_pipeline.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,10 +215,11 @@ def make_pipeline(session_path, **pkwargs):
215215
for sync_option in ('nidq', 'bpod'):
216216
if sync_option in extractor.lower() and not sync == sync_option:
217217
raise ValueError(f'Extractor "{extractor}" and sync "{sync}" do not match')
218+
# TODO Assert sync_label correct here (currently unused)
218219
# Look for the extractor in the behavior extractors module
219220
if hasattr(btasks, extractor):
220221
task = getattr(btasks, extractor)
221-
# This may happen that the extractor is tied to a specific sync task: look for TrialsChoiceWorldBpod for # example
222+
# This may happen that the extractor is tied to a specific sync task: look for TrialsChoiceWorldBpod for example
222223
elif hasattr(btasks, extractor + sync.capitalize()):
223224
task = getattr(btasks, extractor + sync.capitalize())
224225
else:
@@ -229,6 +230,8 @@ def make_pipeline(session_path, **pkwargs):
229230
else:
230231
raise NotImplementedError(
231232
f'Extractor "{extractor}" not found in main IBL pipeline nor in personal projects')
233+
_logger.debug('%s (protocol #%i, task #%i) = %s.%s',
234+
protocol, i, j, task.__module__, task.__name__)
232235
# Rename the class to something more informative
233236
task_name = f'{task.__name__}_{i:02}'
234237
# For now we assume that the second task in the list is always the trials extractor, which is dependent

ibllib/pipes/tasks.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,15 @@
9595

9696

9797
class Task(abc.ABC):
98-
log = '' # place holder to keep the log of the task for registration
98+
log = '' # placeholder to keep the log of the task for registration
9999
cpu = 1 # CPU resource
100100
gpu = 0 # GPU resources: as of now, either 0 or 1
101101
io_charge = 5 # integer percentage
102102
priority = 30 # integer percentage, 100 means highest priority
103103
ram = 4 # RAM needed to run (GB)
104104
one = None # one instance (optional)
105105
level = 0 # level in the pipeline hierarchy: level 0 means there is no parent task
106-
outputs = None # place holder for a list of Path containing output files
106+
outputs = None # placeholder for a list of Path containing output files
107107
time_elapsed_secs = None
108108
time_out_secs = 3600 * 2 # time-out after which a task is considered dead
109109
version = ibllib.__version__
@@ -245,16 +245,21 @@ def run(self, **kwargs):
245245
self.tearDown()
246246
return self.status
247247

248-
def register_datasets(self, one=None, **kwargs):
248+
def register_datasets(self, **kwargs):
249249
"""
250-
Register output datasets form the task to Alyx
251-
:param one:
252-
:param jobid:
253-
:param kwargs: directly passed to the register_dataset function
254-
:return:
250+
Register output datasets from the task to Alyx.
251+
252+
Parameters
253+
----------
254+
kwargs
255+
Directly passed to the `DataHandler.uploadData` method.
256+
257+
Returns
258+
-------
259+
list
260+
The output of the `DataHandler.uploadData` method, e.g. a list of registered datasets.
255261
"""
256262
_ = self.register_images()
257-
258263
return self.data_handler.uploadData(self.outputs, self.version, **kwargs)
259264

260265
def register_images(self, **kwargs):
@@ -737,7 +742,7 @@ def run_alyx_task(tdict=None, session_path=None, one=None, job_deck=None,
737742
# otherwise register data and set (provisional) status to Complete
738743
else:
739744
try:
740-
kwargs = dict(one=one, max_md5_size=max_md5_size)
745+
kwargs = dict(max_md5_size=max_md5_size)
741746
if location == 'server':
742747
# Explicitly pass lab as lab cannot be inferred from path (which the registration client tries to do).
743748
# To avoid making extra REST requests we can also set labs=None if using ONE v1.20.1.

0 commit comments

Comments
 (0)