Commit 227c20d

Merge pull request #1058 from int-brain-lab/datahandler
allow custom data handler to be passed to a task
2 parents 2157109 + fd9f1dc commit 227c20d

5 files changed: +50 -33 lines

brainbox/tests/test_behavior.py

Lines changed: 1 addition & 1 deletion
@@ -259,7 +259,7 @@ def test_get_movement_onset(self):
         expected = [np.nan, 79.66293334, 100.73593334, 129.26693334, np.nan]
         np.testing.assert_array_almost_equal(times, expected)
         with self.assertRaises(ValueError):
-            wheel.get_movement_onset(intervals, np.random.permutation(self.trials['feedback_times']))
+            wheel.get_movement_onset(intervals, np.flipud(self.trials['feedback_times']))


 class TestTraining(unittest.TestCase):

ibllib/oneibl/patcher.py

Lines changed: 7 additions & 7 deletions
@@ -243,8 +243,10 @@ def __init__(self, client_name='default', one=None, label='ibllib patch'):
         # transfers/delete from the current computer to the flatiron: mandatory and executed first
         local_id = self.endpoints['local']['id']
         self.globus_transfer = globus_sdk.TransferData(
-            self.client, local_id, flatiron_id, verify_checksum=True, sync_level='checksum', label=label)
-        self.globus_delete = globus_sdk.DeleteData(self.client, flatiron_id, label=label)
+            source_endpoint=local_id, destination_endpoint=flatiron_id,
+            verify_checksum=True, sync_level='checksum', label=label
+        )
+        self.globus_delete = globus_sdk.DeleteData(endpoint=flatiron_id, label=label)
         # transfers/delete from flatiron to optional third parties to synchronize / delete
         self.globus_transfers_locals = {}
         self.globus_deletes_locals = {}
@@ -303,7 +305,7 @@ def patch_datasets(self, file_list, **kwargs):
             # if there is no transfer already created, initialize it
             if repo_gid not in self.globus_transfers_locals:
                 self.globus_transfers_locals[repo_gid] = globus_sdk.TransferData(
-                    self.client, flatiron_id, repo_gid, verify_checksum=True,
+                    source_endpoint=flatiron_id, destination_endpoint=repo_gid, verify_checksum=True,
                     sync_level='checksum', label=f"{self.label} on {fr['data_repository']}")
             # get the local server path and create the transfer item
             local_server_path = self.to_address(fr['relative_path'], fr['data_repository'])
@@ -343,17 +345,15 @@ def _wait_for_task(resp):
         _wait_for_task(gtc.submit_transfer(self.globus_transfer))
         # re-initialize the globus_transfer property
         self.globus_transfer = globus_sdk.TransferData(
-            gtc,
-            self.globus_transfer['source_endpoint'],
-            self.globus_transfer['destination_endpoint'],
+            source_endpoint=self.globus_transfer['source_endpoint'],
+            destination_endpoint=self.globus_transfer['destination_endpoint'],
             label=self.globus_transfer['label'],
             verify_checksum=True, sync_level='checksum')

         # do the same for deletes
         if len(self.globus_delete['DATA']) > 0:
             _wait_for_task(gtc.submit_delete(self.globus_delete))
             self.globus_delete = globus_sdk.DeleteData(
-                gtc,
                 endpoint=self.globus_delete['endpoint'],
                 label=self.globus_delete['label'])


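The patcher changes above follow the keyword-only construction style that globus-sdk 4 expects: the TransferClient is no longer passed into TransferData or DeleteData, and the transfer documents are built standalone, then handed to the client only at submission time. A minimal sketch of that pattern (the endpoint UUIDs and file paths below are placeholders, not values from this repository):

import globus_sdk

source_id = '11111111-2222-3333-4444-555555555555'  # placeholder endpoint UUID
dest_id = '66666666-7777-8888-9999-000000000000'    # placeholder endpoint UUID

# Build the transfer document without a TransferClient (globus-sdk >= 4 style)
tdata = globus_sdk.TransferData(
    source_endpoint=source_id, destination_endpoint=dest_id,
    verify_checksum=True, sync_level='checksum', label='ibllib patch')
tdata.add_item('/source/path/spikes.times.npy', '/dest/path/spikes.times.npy')

# Deletions are likewise built against a single endpoint
ddata = globus_sdk.DeleteData(endpoint=dest_id, label='ibllib patch')
ddata.add_item('/dest/path/stale.file.npy')

# Submission still goes through an authenticated TransferClient, e.g.:
# transfer_client.submit_transfer(tdata)
# transfer_client.submit_delete(ddata)
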
ibllib/pipes/tasks.py

Lines changed: 29 additions & 23 deletions
@@ -116,7 +116,8 @@ class Task(abc.ABC):
     on_error = 'continue'  # whether to raise an exception on error ('raise') or report the error and continue ('continue')

     def __init__(self, session_path, parents=None, taskid=None, one=None,
-                 machine=None, clobber=True, location='server', scratch_folder=None, on_error='continue', force=False, **kwargs):
+                 machine=None, clobber=True, location='server', scratch_folder=None, on_error='continue',
+                 force=False, data_handler_class=None, **kwargs):
         """
         Base task class
         :param session_path: session path
@@ -127,7 +128,10 @@ def __init__(self, session_path, parents=None, taskid=None, one=None,
         :param clobber: whether or not to overwrite log on rerun
         :param location: location where task is run. Options are 'server' (lab local servers'), 'remote' (remote compute node,
             data required for task downloaded via one), 'AWS' (remote compute node, data required for task downloaded via AWS),
-            or 'SDSC' (SDSC flatiron compute node)
+            or 'SDSC' (SDSC flatiron compute node). The data_handler_class parameter will override the location in
+            determining the handler class.
+        :param data_handler_class: custom class to handle data. If not provided, the location will
+            be used to infer the handler class.
         :param scratch_folder: optional: Path where to write intermediate temporary data
         :param force: whether to re-download missing input files on local server if not present
         :param args: running arguments
@@ -149,6 +153,7 @@ def __init__(self, session_path, parents=None, taskid=None, one=None,
         self.plot_tasks = []  # Plotting task/ tasks to create plot outputs during the task
         self.scratch_folder = scratch_folder
         self.kwargs = kwargs
+        self.data_handler_class = data_handler_class

     @property
     def signature(self) -> Dict[str, List]:
@@ -551,27 +556,28 @@ def get_data_handler(self, location=None):
         Gets the relevant data handler based on location argument
         :return:
         """
-        location = str.lower(location or self.location)
-        if location == 'local':
-            return data_handlers.LocalDataHandler(self.session_path, self.signature, one=self.one)
-        self.one = self.one or ONE()
-        if location == 'server':
-            dhandler = data_handlers.ServerDataHandler(self.session_path, self.signature, one=self.one)
-        elif location == 'serverglobus':
-            dhandler = data_handlers.ServerGlobusDataHandler(self.session_path, self.signature, one=self.one)
-        elif location == 'remote':
-            dhandler = data_handlers.RemoteHttpDataHandler(self.session_path, self.signature, one=self.one)
-        elif location == 'aws':
-            dhandler = data_handlers.RemoteAwsDataHandler(self.session_path, self.signature, one=self.one)
-        elif location == 'sdsc':
-            dhandler = data_handlers.SDSCDataHandler(self.session_path, self.signature, one=self.one)
-        elif location == 'popeye':
-            dhandler = data_handlers.PopeyeDataHandler(self.session_path, self.signature, one=self.one)
-        elif location == 'ec2':
-            dhandler = data_handlers.RemoteEC2DataHandler(self.session_path, self.signature, one=self.one)
-        else:
-            raise ValueError(f'Unknown location "{location}"')
-        return dhandler
+        if self.data_handler_class is None:
+            location = str.lower(location or self.location)
+            if location == 'local':
+                return data_handlers.LocalDataHandler(self.session_path, self.signature, one=self.one)
+            self.one = self.one or ONE()
+            if location == 'server':
+                self.data_handler_class = data_handlers.ServerDataHandler
+            elif location == 'serverglobus':
+                self.data_handler_class = data_handlers.ServerGlobusDataHandler
+            elif location == 'remote':
+                self.data_handler_class = data_handlers.RemoteHttpDataHandler
+            elif location == 'aws':
+                self.data_handler_class = data_handlers.RemoteAwsDataHandler
+            elif location == 'sdsc':
+                self.data_handler_class = data_handlers.SDSCDataHandler
+            elif location == 'popeye':
+                self.data_handler_class = data_handlers.PopeyeDataHandler
+            elif location == 'ec2':
+                self.data_handler_class = data_handlers.RemoteEC2DataHandler
+            else:
+                raise ValueError(f'Unknown location "{location}"')
+        return self.data_handler_class(self.session_path, self.signature, one=self.one)

     @staticmethod
     def make_lock_file(taskname='', time_out_secs=7200):

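Together with the constructor change, get_data_handler() now instantiates self.data_handler_class(session_path, signature, one=one) whenever a custom class is supplied, and only falls back to the location lookup when it is not. A minimal usage sketch follows; MyTask, QuietDataHandler and the session path are hypothetical names for illustration, while Task, DataHandler and the data_handler_class keyword come from this changeset:

from ibllib.pipes.tasks import Task
from ibllib.oneibl.data_handlers import DataHandler


class QuietDataHandler(DataHandler):
    """Hypothetical handler; a real one would override the staging/upload behaviour."""
    pass


class MyTask(Task):
    def _run(self):
        return []  # no output files in this sketch


session_path = '/data/Subjects/ZM_0000/2024-01-01/001'  # hypothetical session

# Default: the handler class is inferred from `location` ('server' by default).
task = MyTask(session_path)

# New in this PR: pass a handler class directly; it is instantiated as
# data_handler_class(session_path, signature, one=one), bypassing the location lookup.
task = MyTask(session_path, data_handler_class=QuietDataHandler)
handler = task.get_data_handler()  # -> QuietDataHandler instance
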
ibllib/tests/test_tasks.py

Lines changed: 11 additions & 0 deletions
@@ -10,6 +10,7 @@
 import numpy as np

 import ibllib.pipes.tasks
+import ibllib.oneibl.data_handlers
 from ibllib.pipes.base_tasks import ExperimentDescriptionRegisterRaw
 from ibllib.pipes.video_tasks import VideoConvert
 from ibllib.io import session_params
@@ -401,6 +402,16 @@ def test_input_files_to_register(self):
         for f in ('alf/foo.bar.*', 'alf/bar.baz.npy', 'alf/baz.foo.npy'):
             self.assertRegex(cm.output[-1], re.escape(f))

+    def test_location_data_handler(self):
+        task = Task00(self.session_path)
+        self.assertIsInstance(task.get_data_handler(), ibllib.oneibl.data_handlers.ServerDataHandler)
+
+        class TotoDataHandler(ibllib.oneibl.data_handlers.DataHandler):
+            pass
+
+        task = Task00(self.session_path, data_handler_class=TotoDataHandler)
+        self.assertIsInstance(task.get_data_handler(), TotoDataHandler)
+

 class TestMisc(unittest.TestCase):
     """Tests for misc functions in ibllib.pipes.tasks module."""

requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@ boto3
 click>=7.0.0
 colorlog>=4.0.2
 flake8>=3.7.8
-globus-sdk
+globus-sdk>=4.0.0
 graphviz
 matplotlib>=3.0.3
 numba>=0.56
@@ -28,7 +28,7 @@ iblutil>=1.13.0
 iblqt>=0.8.2
 mtscomp>=1.0.1
 ONE-api>=3.2.0
-phylib>=2.6.0
+phylib>=2.6.2
 psychofit
 slidingRP>=1.1.1  # steinmetz lab refractory period metrics
 pyqt5
