@@ -33,19 +33,47 @@ class ParallelFactory:
3333 """Abstract class that should be subclassed to provide parallel functions."""
3434
3535 _BATCH_SIZE = "PARALLEL_BATCH_SIZE"
36+ _CHUNK_SIZE = "PARALLEL_CHUNK_SIZE"
3637
37- def __init__ (self , * args , batch_size = None , ** kwargs ): # pylint: disable=unused-argument
38+ # pylint: disable=unused-argument
39+ def __init__ (self , batch_size = None , chunk_size = None , ** kwargs ):
3840 self .batch_size = batch_size or int (os .getenv (self ._BATCH_SIZE , "0" )) or None
39- self .nb_processes = 1
4041 L .info ("Using %s=%s" , self ._BATCH_SIZE , self .batch_size )
4142
43+ self .chunk_size = chunk_size or int (os .getenv (self ._CHUNK_SIZE , "0" )) or None
44+ L .info ("Using %s=%s" , self ._CHUNK_SIZE , self .chunk_size )
45+
46+ self .nb_processes = 1
47+
4248 @abstractmethod
43- def get_mapper (self ):
49+ def get_mapper (self , batch_size = None , chunk_size = None , ** kwargs ):
4450 """Return a mapper function that can be used to execute functions in parallel."""
4551
4652 def shutdown (self ):
4753 """Can be used to cleanup."""
4854
55+ def _with_batches (self , mapper , func , iterable , batch_size = None ):
56+ """Wrapper on mapper function creating batches of iterable to give to mapper.
57+
58+ The batch_size is an int corresponding to the number of evaluations in each batch.
59+ """
60+ if isinstance (iterable , Iterator ):
61+ iterable = list (iterable )
62+
63+ batch_size = batch_size or self .batch_size
64+ if batch_size is not None :
65+ iterables = np .array_split (iterable , len (iterable ) // min (batch_size , len (iterable )))
66+ else :
67+ iterables = [iterable ]
68+
69+ for _iterable in iterables :
70+ yield from mapper (func , _iterable )
71+
72+ def _chunksize_to_kwargs (self , chunk_size , kwargs , label = "chunk_size" ):
73+ chunk_size = chunk_size or self .chunk_size
74+ if chunk_size is not None :
75+ kwargs [label ] = chunk_size
76+
4977
5078class NoDaemonProcess (multiprocessing .Process ):
5179 """Class that represents a non-daemon process"""
@@ -72,26 +100,10 @@ class NestedPool(Pool): # pylint: disable=abstract-method
72100 Process = NoDaemonProcess
73101
74102
75- def _with_batches (mapper , func , iterable , batch_size = None ):
76- """Wrapper on mapper function creating batches of iterable to give to mapper.
77-
78- The batch_size is an int corresponding to the number of evaluation in each batch/
79- """
80- if isinstance (iterable , Iterator ):
81- iterable = list (iterable )
82- if batch_size is not None :
83- iterables = np .array_split (iterable , len (iterable ) // min (batch_size , len (iterable )))
84- else :
85- iterables = [iterable ]
86-
87- for _iterable in iterables :
88- yield from mapper (func , _iterable )
89-
90-
91103class SerialFactory (ParallelFactory ):
92104 """Factory that do not work in parallel."""
93105
94- def get_mapper (self ):
106+ def get_mapper (self , batch_size = None , chunk_size = None , ** kwargs ):
95107 """Get a map."""
96108 return map
97109
@@ -101,18 +113,24 @@ class MultiprocessingFactory(ParallelFactory):
101113
102114 _CHUNKSIZE = "PARALLEL_CHUNKSIZE"
103115
104- def __init__ (self , * args , processes = None , ** kwargs ):
116+ def __init__ (self , processes = None , ** kwargs ):
105117 """Initialize multiprocessing factory."""
106118
107- super ().__init__ ()
108- self .pool = NestedPool (* args , ** kwargs )
119+ super ().__init__ (** kwargs )
120+
121+ self .pool = NestedPool (processes = processes )
109122 self .nb_processes = processes or os .cpu_count ()
110123
111- def get_mapper (self ):
124+ def get_mapper (self , batch_size = None , chunk_size = None , ** kwargs ):
112125 """Get a NestedPool."""
126+ self ._chunksize_to_kwargs (chunk_size , kwargs )
113127
114128 def _mapper (func , iterable ):
115- return _with_batches (self .pool .imap_unordered , func , iterable , self .batch_size )
129+ return self ._with_batches (
130+ partial (self .pool .imap_unordered , ** kwargs ),
131+ func ,
132+ iterable , batch_size = batch_size ,
133+ )
116134
117135 return _mapper
118136
@@ -126,24 +144,29 @@ class IPyParallelFactory(ParallelFactory):
126144
127145 _IPYTHON_PROFILE = "IPYTHON_PROFILE"
128146
129- def __init__ (self , * args , * *kwargs ):
147+ def __init__ (self , ** kwargs ):
130148 """Initialize the ipyparallel factory."""
131149
132- super ().__init__ ()
150+ super ().__init__ (** kwargs )
133151 self .rc = None
134152 self .nb_processes = 1
135153
136- def get_mapper (self ):
154+ def get_mapper (self , batch_size = None , chunk_size = None , ** kwargs ):
137155 """Get an ipyparallel mapper using the profile name provided."""
138- profile = os .getenv (self ._IPYTHON_PROFILE , "DEFAULT_IPYTHON_PROFILE" )
156+ profile = os .getenv (self ._IPYTHON_PROFILE , None )
139157 L .debug ("Using %s=%s" , self ._IPYTHON_PROFILE , profile )
140158 self .rc = ipyparallel .Client (profile = profile )
141159 self .nb_processes = len (self .rc .ids )
142160 lview = self .rc .load_balanced_view ()
143161
162+ if "ordered" not in kwargs :
163+ kwargs ["ordered" ] = False
164+
165+ self ._chunksize_to_kwargs (chunk_size , kwargs )
166+
144167 def _mapper (func , iterable ):
145- return _with_batches (
146- partial (lview .imap , ordered = False ), func , iterable , self . batch_size
168+ return self . _with_batches (
169+ partial (lview .imap , ** kwargs ), func , iterable , batch_size = batch_size
147170 )
148171
149172 return _mapper
@@ -159,7 +182,7 @@ class DaskFactory(ParallelFactory):
159182
160183 _SCHEDULER_PATH = "PARALLEL_DASK_SCHEDULER_PATH"
161184
162- def __init__ (self , * args , * *kwargs ):
185+ def __init__ (self , ** kwargs ):
163186 """Initialize the dask factory."""
164187 dask_scheduler_path = os .getenv (self ._SCHEDULER_PATH )
165188 if dask_scheduler_path :
@@ -172,7 +195,7 @@ def __init__(self, *args, **kwargs):
172195 L .info ("Starting dask_mpi..." )
173196 self .client = dask .distributed .Client ()
174197 self .nb_processes = len (self .client .scheduler_info ()["workers" ])
175- super ().__init__ ()
198+ super ().__init__ (** kwargs )
176199
177200 def shutdown (self ):
178201 """Retire the workers on the scheduler."""
@@ -181,16 +204,17 @@ def shutdown(self):
181204 self .client .retire_workers ()
182205 self .client = None
183206
184- def get_mapper (self ):
207+ def get_mapper (self , batch_size = None , chunk_size = None , ** kwargs ):
185208 """Get a Dask mapper."""
209+ self ._chunksize_to_kwargs (chunk_size , kwargs , label = "batch_size" )
186210
187211 def _mapper (func , iterable ):
188212 def _dask_mapper (func , iterable ):
189- futures = self .client .map (func , iterable )
213+ futures = self .client .map (func , iterable , ** kwargs )
190214 for _future , result in dask .distributed .as_completed (futures , with_results = True ):
191215 yield result
192216
193- return _with_batches (_dask_mapper , func , iterable , self . batch_size )
217+ return self . _with_batches (_dask_mapper , func , iterable , batch_size = batch_size )
194218
195219 return _mapper
196220
0 commit comments