Skip to content

Commit 5932080

Browse files
committed
allow custom data handler to be passed to a task
1 parent 5af9425 commit 5932080

File tree

2 files changed

+40
-23
lines changed

2 files changed

+40
-23
lines changed

ibllib/pipes/tasks.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ class Task(abc.ABC):
116116
on_error = 'continue' # whether to raise an exception on error ('raise') or report the error and continue ('continue')
117117

118118
def __init__(self, session_path, parents=None, taskid=None, one=None,
119-
machine=None, clobber=True, location='server', scratch_folder=None, on_error='continue', force=False, **kwargs):
119+
machine=None, clobber=True, location='server', scratch_folder=None, on_error='continue',
120+
force=False, data_handler_class=None, **kwargs):
120121
"""
121122
Base task class
122123
:param session_path: session path
@@ -127,7 +128,10 @@ def __init__(self, session_path, parents=None, taskid=None, one=None,
127128
:param clobber: whether or not to overwrite log on rerun
128129
:param location: location where task is run. Options are 'server' (lab local servers), 'remote' (remote compute node,
129130
data required for task downloaded via one), 'AWS' (remote compute node, data required for task downloaded via AWS),
130-
or 'SDSC' (SDSC flatiron compute node)
131+
or 'SDSC' (SDSC flatiron compute node). The data_handler_class parameter will override the location in
132+
determining the handler class.
133+
:param data_handler_class: custom class to handle data. If not provided, the location will
134+
be used to infer the handler class.
131135
:param scratch_folder: optional: Path where to write intermediate temporary data
132136
:param force: whether to re-download missing input files on local server if not present
133137
:param args: running arguments
@@ -149,6 +153,7 @@ def __init__(self, session_path, parents=None, taskid=None, one=None,
149153
self.plot_tasks = [] # Plotting task/ tasks to create plot outputs during the task
150154
self.scratch_folder = scratch_folder
151155
self.kwargs = kwargs
156+
self.data_handler_class = data_handler_class
152157

153158
@property
154159
def signature(self) -> Dict[str, List]:
@@ -551,27 +556,28 @@ def get_data_handler(self, location=None):
551556
Gets the relevant data handler based on the location argument, unless a custom data_handler_class was provided
552557
:return:
553558
"""
554-
location = str.lower(location or self.location)
555-
if location == 'local':
556-
return data_handlers.LocalDataHandler(self.session_path, self.signature, one=self.one)
557-
self.one = self.one or ONE()
558-
if location == 'server':
559-
dhandler = data_handlers.ServerDataHandler(self.session_path, self.signature, one=self.one)
560-
elif location == 'serverglobus':
561-
dhandler = data_handlers.ServerGlobusDataHandler(self.session_path, self.signature, one=self.one)
562-
elif location == 'remote':
563-
dhandler = data_handlers.RemoteHttpDataHandler(self.session_path, self.signature, one=self.one)
564-
elif location == 'aws':
565-
dhandler = data_handlers.RemoteAwsDataHandler(self.session_path, self.signature, one=self.one)
566-
elif location == 'sdsc':
567-
dhandler = data_handlers.SDSCDataHandler(self.session_path, self.signature, one=self.one)
568-
elif location == 'popeye':
569-
dhandler = data_handlers.PopeyeDataHandler(self.session_path, self.signature, one=self.one)
570-
elif location == 'ec2':
571-
dhandler = data_handlers.RemoteEC2DataHandler(self.session_path, self.signature, one=self.one)
572-
else:
573-
raise ValueError(f'Unknown location "{location}"')
574-
return dhandler
559+
if self.data_handler_class is None:
560+
location = str.lower(location or self.location)
561+
if location == 'local':
562+
return data_handlers.LocalDataHandler(self.session_path, self.signature, one=self.one)
563+
self.one = self.one or ONE()
564+
if location == 'server':
565+
self.data_handler_class = data_handlers.ServerDataHandler
566+
elif location == 'serverglobus':
567+
self.data_handler_class = data_handlers.ServerGlobusDataHandler
568+
elif location == 'remote':
569+
self.data_handler_class = data_handlers.RemoteHttpDataHandler
570+
elif location == 'aws':
571+
self.data_handler_class = data_handlers.RemoteAwsDataHandler
572+
elif location == 'sdsc':
573+
self.data_handler_class = data_handlers.SDSCDataHandler
574+
elif location == 'popeye':
575+
self.data_handler_class = data_handlers.PopeyeDataHandler
576+
elif location == 'ec2':
577+
self.data_handler_class = data_handlers.RemoteEC2DataHandler
578+
else:
579+
raise ValueError(f'Unknown location "{location}"')
580+
return self.data_handler_class(self.session_path, self.signature, one=self.one)
575581

576582
@staticmethod
577583
def make_lock_file(taskname='', time_out_secs=7200):

ibllib/tests/test_tasks.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy as np
1111

1212
import ibllib.pipes.tasks
13+
import ibllib.oneibl.data_handlers
1314
from ibllib.pipes.base_tasks import ExperimentDescriptionRegisterRaw
1415
from ibllib.pipes.video_tasks import VideoConvert
1516
from ibllib.io import session_params
@@ -401,6 +402,16 @@ def test_input_files_to_register(self):
401402
for f in ('alf/foo.bar.*', 'alf/bar.baz.npy', 'alf/baz.foo.npy'):
402403
self.assertRegex(cm.output[-1], re.escape(f))
403404

405+
def test_location_data_handler(self):
406+
task = Task00(self.session_path)
407+
self.assertIsInstance(task.get_data_handler(), ibllib.oneibl.data_handlers.ServerDataHandler)
408+
409+
class TotoDataHandler(ibllib.oneibl.data_handlers.DataHandler):
410+
pass
411+
412+
task = Task00(self.session_path, data_handler_class=TotoDataHandler)
413+
self.assertIsInstance(task.get_data_handler(), TotoDataHandler)
414+
404415

405416
class TestMisc(unittest.TestCase):
406417
"""Tests for misc functions in ibllib.pipes.tasks module."""

0 commit comments

Comments
 (0)