@@ -116,7 +116,8 @@ class Task(abc.ABC):
116116 on_error = 'continue' # whether to raise an exception on error ('raise') or report the error and continue ('continue')
117117
118118 def __init__ (self , session_path , parents = None , taskid = None , one = None ,
119- machine = None , clobber = True , location = 'server' , scratch_folder = None , on_error = 'continue' , force = False , ** kwargs ):
119+ machine = None , clobber = True , location = 'server' , scratch_folder = None , on_error = 'continue' ,
120+ force = False , data_handler_class = None , ** kwargs ):
120121 """
121122 Base task class
122123 :param session_path: session path
@@ -127,7 +128,10 @@ def __init__(self, session_path, parents=None, taskid=None, one=None,
127128 :param clobber: whether or not to overwrite log on rerun
128129 :param location: location where task is run. Options are 'server' (lab local servers), 'remote' (remote compute node,
129130 data required for task downloaded via one), 'AWS' (remote compute node, data required for task downloaded via AWS),
130- or 'SDSC' (SDSC flatiron compute node)
131+ or 'SDSC' (SDSC flatiron compute node). The data_handler_class parameter will override the location in
132+ determining the handler class.
133+ :param data_handler_class: custom class to handle data. If not provided, the location will
134+ be used to infer the handler class.
131135 :param scratch_folder: optional: Path where to write intermediate temporary data
132136 :param force: whether to re-download missing input files on local server if not present
133137 :param args: running arguments
@@ -149,6 +153,7 @@ def __init__(self, session_path, parents=None, taskid=None, one=None,
149153 self .plot_tasks = [] # Plotting task/ tasks to create plot outputs during the task
150154 self .scratch_folder = scratch_folder
151155 self .kwargs = kwargs
156+ self .data_handler_class = data_handler_class
152157
153158 @property
154159 def signature (self ) -> Dict [str , List ]:
@@ -551,27 +556,28 @@ def get_data_handler(self, location=None):
551556 Gets the relevant data handler based on location argument
552557 :return:
553558 """
554- location = str .lower (location or self .location )
555- if location == 'local' :
556- return data_handlers .LocalDataHandler (self .session_path , self .signature , one = self .one )
557- self .one = self .one or ONE ()
558- if location == 'server' :
559- dhandler = data_handlers .ServerDataHandler (self .session_path , self .signature , one = self .one )
560- elif location == 'serverglobus' :
561- dhandler = data_handlers .ServerGlobusDataHandler (self .session_path , self .signature , one = self .one )
562- elif location == 'remote' :
563- dhandler = data_handlers .RemoteHttpDataHandler (self .session_path , self .signature , one = self .one )
564- elif location == 'aws' :
565- dhandler = data_handlers .RemoteAwsDataHandler (self .session_path , self .signature , one = self .one )
566- elif location == 'sdsc' :
567- dhandler = data_handlers .SDSCDataHandler (self .session_path , self .signature , one = self .one )
568- elif location == 'popeye' :
569- dhandler = data_handlers .PopeyeDataHandler (self .session_path , self .signature , one = self .one )
570- elif location == 'ec2' :
571- dhandler = data_handlers .RemoteEC2DataHandler (self .session_path , self .signature , one = self .one )
572- else :
573- raise ValueError (f'Unknown location "{ location } "' )
574- return dhandler
559+ if self .data_handler_class is None :
560+ location = str .lower (location or self .location )
561+ if location == 'local' :
562+ return data_handlers .LocalDataHandler (self .session_path , self .signature , one = self .one )
563+ self .one = self .one or ONE ()
564+ if location == 'server' :
565+ self .data_handler_class = data_handlers .ServerDataHandler
566+ elif location == 'serverglobus' :
567+ self .data_handler_class = data_handlers .ServerGlobusDataHandler
568+ elif location == 'remote' :
569+ self .data_handler_class = data_handlers .RemoteHttpDataHandler
570+ elif location == 'aws' :
571+ self .data_handler_class = data_handlers .RemoteAwsDataHandler
572+ elif location == 'sdsc' :
573+ self .data_handler_class = data_handlers .SDSCDataHandler
574+ elif location == 'popeye' :
575+ self .data_handler_class = data_handlers .PopeyeDataHandler
576+ elif location == 'ec2' :
577+ self .data_handler_class = data_handlers .RemoteEC2DataHandler
578+ else :
579+ raise ValueError (f'Unknown location "{ location } "' )
580+ return self .data_handler_class (self .session_path , self .signature , one = self .one )
575581
576582 @staticmethod
577583 def make_lock_file (taskname = '' , time_out_secs = 7200 ):
0 commit comments