1414 import dask_mpi
1515
1616 dask_available = True
17- except ImportError :
17+ except ImportError : # pragma: no cover
1818 dask_available = False
1919
2020try :
2121 import ipyparallel
2222
2323 ipyparallel_available = True
24- except ImportError :
24+ except ImportError : # pragma: no cover
2525 ipyparallel_available = False
2626
2727
2828L = logging .getLogger (__name__ )
2929
3030
31+ def _func_wrapper (data , func , func_args , func_kwargs ):
32+ """Function wrapper used to pass args and kwargs."""
33+ return func (data , * func_args , ** func_kwargs )
34+
35+
3136class ParallelFactory :
3237 """Abstract class that should be subclassed to provide parallel functions."""
3338
@@ -56,6 +61,10 @@ def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
def shutdown(self):
    """Cleanup hook; this implementation does nothing and returns ``None``."""
    return None
5863
def mappable_func(self, func, *args, **kwargs):
    """Bind extra args and kwargs to ``func`` before handing it to a mapper.

    Args:
        func: callable that will receive each mapped item as first argument.
        *args: extra positional arguments stored for the call.
        **kwargs: extra keyword arguments stored for the call.

    Returns:
        A ``functools.partial`` of ``_func_wrapper`` holding ``func``, ``args``
        and ``kwargs``; calling it with one item runs
        ``func(item, *args, **kwargs)``.
    """
    bound = {"func": func, "func_args": args, "func_kwargs": kwargs}
    return partial(_func_wrapper, **bound)
67+
5968 def _with_batches (self , mapper , func , iterable , batch_size = None ):
6069 """Wrapper on mapper function creating batches of iterable to give to mapper.
6170
@@ -95,7 +104,7 @@ def __init__(self, group=None, target=None, name=None, args=(), kwargs={}):
95104
96105 def _get_daemon (self ): # pylint: disable=no-self-use
97106 """Get daemon flag"""
98- return False
107+ return False # pragma: no cover
99108
100109 def _set_daemon (self , value ):
101110 """Set daemon flag"""
@@ -114,7 +123,12 @@ class SerialFactory(ParallelFactory):
114123
def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
    """Build a mapper that relies on the builtin ``map``.

    ``batch_size``, ``chunk_size`` and ``kwargs`` are accepted but are not used
    by this implementation.

    Returns:
        A callable ``_mapper(func, iterable, *func_args, **func_kwargs)`` that
        binds the extra arguments to ``func`` with ``self.mappable_func`` and
        returns the result of ``self._with_batches(map, ...)``.
    """

    def _mapper(func, iterable, *func_args, **func_kwargs):
        # Fold the extra arguments into a single-argument callable first.
        wrapped = self.mappable_func(func, *func_args, **func_kwargs)
        return self._with_batches(map, wrapped, iterable)

    return _mapper
118132
119133
120134class MultiprocessingFactory (ParallelFactory ):
@@ -134,7 +148,8 @@ def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
134148 """Get a NestedPool."""
135149 self ._chunksize_to_kwargs (chunk_size , kwargs , label = "chunksize" )
136150
137- def _mapper (func , iterable ):
151+ def _mapper (func , iterable , * func_args , ** func_kwargs ):
152+ func = self .mappable_func (func , * func_args , ** func_kwargs )
138153 return self ._with_batches (
139154 partial (self .pool .imap_unordered , ** kwargs ),
140155 func ,
@@ -164,12 +179,13 @@ def __init__(self, batch_size=None, chunk_size=None, profile=None, **kwargs):
164179
165180 def get_mapper (self , batch_size = None , chunk_size = None , ** kwargs ):
166181 """Get an ipyparallel mapper using the profile name provided."""
167- if "ordered" not in kwargs :
182+ if "ordered" not in kwargs : # pragma: no cover
168183 kwargs ["ordered" ] = False
169184
170185 self ._chunksize_to_kwargs (chunk_size , kwargs )
171186
172- def _mapper (func , iterable ):
187+ def _mapper (func , iterable , * func_args , ** func_kwargs ):
188+ func = self .mappable_func (func , * func_args , ** func_kwargs )
173189 return self ._with_batches (
174190 partial (self .lview .imap , ** kwargs ), func , iterable , batch_size = batch_size
175191 )
@@ -178,7 +194,7 @@ def _mapper(func, iterable):
178194
def shutdown(self):
    """Close ``self.rc`` when it is set; do nothing when it is ``None``."""
    client = self.rc
    if client is None:
        return
    client.close()  # pragma: no cover
183199
184200
@@ -193,11 +209,11 @@ def __init__(
193209 """Initialize the dask factory."""
194210 dask_scheduler_path = scheduler_file or os .getenv (self ._SCHEDULER_PATH )
195211 self .interactive = True
196- if dask_scheduler_path :
212+ if dask_scheduler_path : # pragma: no cover
197213 L .info ("Connecting dask_mpi with scheduler %s" , dask_scheduler_path )
198- if address :
214+ if address : # pragma: no cover
199215 L .info ("Connecting dask_mpi with address %s" , address )
200- if not dask_scheduler_path and not address :
216+ if not dask_scheduler_path and not address : # pragma: no cover
201217 self .interactive = False
202218 dask_mpi .initialize ()
203219 L .info ("Starting dask_mpi..." )
@@ -213,19 +229,20 @@ def shutdown(self):
213229 """Close the scheduler and the cluster if it was created by the factory."""
214230 cluster = self .client .cluster
215231 self .client .close ()
216- if not self .interactive :
232+ if not self .interactive : # pragma: no cover
217233 cluster .close ()
218234
def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
    """Get a Dask mapper.

    Args:
        batch_size: forwarded to ``self._with_batches`` as ``batch_size``.
        chunk_size: handed to ``self._chunksize_to_kwargs`` together with
            ``kwargs`` and the label ``"batch_size"`` (presumably stored in
            ``kwargs`` under that key -- confirm against the helper).
        **kwargs: keyword arguments passed to ``self.client.map`` for every
            submission made by the returned mapper.

    Returns:
        A callable ``_mapper(func, iterable, *func_args, **func_kwargs)`` that
        binds the extra arguments to ``func`` with ``self.mappable_func`` and
        returns the result of ``self._with_batches``.
    """
    self._chunksize_to_kwargs(chunk_size, kwargs, label="batch_size")

    def _mapper(func, iterable, *func_args, **func_kwargs):
        def _dask_mapper(func, iterable, **extra_kwargs):
            # The keyword arguments prepared by get_mapper must reach client.map.
            # Naming this parameter ``kwargs`` would shadow them, and a caller
            # that passes only (func, iterable) would then drop them silently.
            # Keywords given directly to this function take precedence.
            map_kwargs = {**kwargs, **extra_kwargs}
            futures = self.client.map(func, iterable, **map_kwargs)
            # Yield results in completion order.
            for _future, result in dask.distributed.as_completed(futures, with_results=True):
                yield result

        func = self.mappable_func(func, *func_args, **func_kwargs)
        return self._with_batches(_dask_mapper, func, iterable, batch_size=batch_size)

    return _mapper
@@ -245,9 +262,9 @@ def init_parallel_factory(parallel_lib, *args, **kwargs):
245262 None : SerialFactory ,
246263 "multiprocessing" : MultiprocessingFactory ,
247264 }
248- if dask_available :
265+ if dask_available : # pragma: no cover
249266 parallel_factories ["dask" ] = DaskFactory
250- if ipyparallel_available :
267+ if ipyparallel_available : # pragma: no cover
251268 parallel_factories ["ipyparallel" ] = IPyParallelFactory
252269
253270 try :
0 commit comments