10
10
from .errors import DataJointError
11
11
from .table import FreeTable
12
12
import signal
13
+ import multiprocessing as mp
13
14
14
15
# noinspection PyExceptionInherit,PyCallingNonCallable
15
16
16
17
logger = logging .getLogger (__name__ )
17
18
18
19
20
def initializer(table):
    """
    Worker start-up hook: remember the table on this process, then reconnect it.

    The table is stored as the ``table`` attribute of the current process
    object, which is where call_make_key() looks it up afterwards.

    :param table: the table object handed to this worker process; presumably it
        arrives disconnected after being pickled by the parent -- confirm
        against the caller
    """
    worker = mp.current_process()
    worker.table = table  # read back later by call_make_key()
    table.connection.connect()  # reconnect
25
+
26
def call_make_key(key):
    """
    Run make_key() for one key on the table stored on the current process.

    The table is read from the ``table`` attribute of the current process
    object, so that attribute must have been set beforehand (initializer()
    does this).

    :param key: the key passed through unchanged to table.make_key()
    :return: whatever table.make_key(key) returns
    """
    return mp.current_process().table.make_key(key)
31
+
32
+
19
33
class AutoPopulate :
20
34
"""
21
35
AutoPopulate is a mixin class that adds the method populate() to a Relation class.
@@ -103,29 +117,36 @@ def _jobs_to_do(self, restrictions):
103
117
104
118
def populate (self , * restrictions , suppress_errors = False , return_exception_objects = False ,
105
119
reserve_jobs = False , order = "original" , limit = None , max_calls = None ,
106
- display_progress = False ):
120
+ display_progress = False , multiprocess = False ):
107
121
"""
108
122
rel.populate() calls rel.make(key) for every primary key in self.key_source
109
123
for which there is not already a tuple in rel.
110
124
:param restrictions: a list of restrictions each restrict (rel.key_source - target.proj())
111
125
:param suppress_errors: if True, do not terminate execution.
112
126
:param return_exception_objects: return error objects instead of just error messages
113
- :param reserve_jobs: if true, reserves job to populate in asynchronous fashion
127
+ :param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion
114
128
:param order: "original"|"reverse"|"random" - the order of execution
129
+ :param limit: if not None, check at most this many keys
130
+ :param max_calls: if not None, populate at most this many keys
115
131
:param display_progress: if True, report progress_bar
116
- :param limit : if not None, checks at most that many keys
117
- :param max_calls: if not None, populates at max that many keys
132
+ :param multiprocess : if True, use as many processes as CPU cores, or use the integer
133
+ number of processes specified
118
134
"""
119
135
if self .connection .in_transaction :
120
136
raise DataJointError ('Populate cannot be called during a transaction.' )
121
137
122
138
valid_order = ['original' , 'reverse' , 'random' ]
123
139
if order not in valid_order :
124
140
raise DataJointError ('The order argument must be one of %s' % str (valid_order ))
125
- error_list = [] if suppress_errors else None
126
141
jobs = self .connection .schemas [self .target .database ].jobs if reserve_jobs else None
127
142
128
- # define and setup signal handler for SIGTERM
143
+ self ._make_key_kwargs = {'suppress_errors' :suppress_errors ,
144
+ 'return_exception_objects' :return_exception_objects ,
145
+ 'reserve_jobs' :reserve_jobs ,
146
+ 'jobs' :jobs ,
147
+ }
148
+
149
+ # define and set up signal handler for SIGTERM:
129
150
if reserve_jobs :
130
151
def handler (signum , frame ):
131
152
logger .info ('Populate terminated by SIGTERM' )
@@ -138,55 +159,101 @@ def handler(signum, frame):
138
159
elif order == "random" :
139
160
random .shuffle (keys )
140
161
141
- call_count = 0
142
162
logger .info ('Found %d keys to populate' % len (keys ))
143
163
144
- make = self ._make_tuples if hasattr (self , '_make_tuples' ) else self .make
164
+ if max_calls is not None :
165
+ keys = keys [:max_calls ]
166
+ nkeys = len (keys )
145
167
146
- for key in (tqdm (keys ) if display_progress else keys ):
147
- if max_calls is not None and call_count >= max_calls :
148
- break
149
- if not reserve_jobs or jobs .reserve (self .target .table_name , self ._job_key (key )):
150
- self .connection .start_transaction ()
151
- if key in self .target : # already populated
152
- self .connection .cancel_transaction ()
153
- if reserve_jobs :
154
- jobs .complete (self .target .table_name , self ._job_key (key ))
168
+ if multiprocess : # True or int, presumably
169
+ if multiprocess is True :
170
+ nproc = mp .cpu_count ()
171
+ else :
172
+ if not isinstance (multiprocess , int ):
173
+ raise DataJointError ("multiprocess can be False, True or a positive integer" )
174
+ nproc = multiprocess
175
+ else :
176
+ nproc = 1
177
+ nproc = min (nproc , nkeys ) # no sense spawning more than can be used
178
+ error_list = []
179
+ if nproc > 1 : # spawn multiple processes
180
+ # prepare to pickle self:
181
+ self .connection .close () # disconnect parent process from MySQL server
182
+ del self .connection ._conn .ctx # SSLContext is not picklable
183
+ print ('*** Spawning pool of %d processes' % nproc )
184
+ # send pickled copy of self to each process,
185
+ # each worker process calls initializer(*initargs) when it starts
186
+ with mp .Pool (nproc , initializer , (self ,)) as pool :
187
+ if display_progress :
188
+ with tqdm (total = nkeys ) as pbar :
189
+ for error in pool .imap (call_make_key , keys , chunksize = 1 ):
190
+ if error is not None :
191
+ error_list .append (error )
192
+ pbar .update ()
155
193
else :
156
- logger .info ('Populating: ' + str (key ))
157
- call_count += 1
158
- self .__class__ ._allow_insert = True
159
- try :
160
- make (dict (key ))
161
- except (KeyboardInterrupt , SystemExit , Exception ) as error :
162
- try :
163
- self .connection .cancel_transaction ()
164
- except OperationalError :
165
- pass
166
- error_message = '{exception}{msg}' .format (
167
- exception = error .__class__ .__name__ ,
168
- msg = ': ' + str (error ) if str (error ) else '' )
169
- if reserve_jobs :
170
- # show error name and error message (if any)
171
- jobs .error (
172
- self .target .table_name , self ._job_key (key ),
173
- error_message = error_message , error_stack = traceback .format_exc ())
174
- if not suppress_errors or isinstance (error , SystemExit ):
175
- raise
176
- else :
177
- logger .error (error )
178
- error_list .append ((key , error if return_exception_objects else error_message ))
179
- else :
180
- self .connection .commit_transaction ()
181
- if reserve_jobs :
182
- jobs .complete (self .target .table_name , self ._job_key (key ))
183
- finally :
184
- self .__class__ ._allow_insert = False
194
+ for error in pool .imap (call_make_key , keys ):
195
+ if error is not None :
196
+ error_list .append (error )
197
+ self .connection .connect () # reconnect parent process to MySQL server
198
+ else : # use single process
199
+ for key in tqdm (keys ) if display_progress else keys :
200
+ error = self .make_key (key )
201
+ if error is not None :
202
+ error_list .append (error )
185
203
186
- # place back the original signal handler
204
+ del self ._make_key_kwargs # clean up
205
+
206
+ # restore original signal handler:
187
207
if reserve_jobs :
188
208
signal .signal (signal .SIGTERM , old_handler )
189
- return error_list
209
+
210
+ if suppress_errors :
211
+ return error_list
212
+
213
def make_key(self, key):
    """
    Populate a single key using the settings stored in self._make_key_kwargs.

    Steps, in order:
      1. when job reservation is on, try to reserve the key and stop if that fails;
      2. start a transaction; if the key is already in self.target, cancel the
         transaction (marking the job complete when reserving) and stop;
      3. otherwise call self._make_tuples (if defined) or self.make with a dict
         copy of the key, committing on success and cancelling on failure.

    :param key: the key to populate
    :return: None when nothing was suppressed; otherwise a tuple
        (key, error) where error is the exception object if
        return_exception_objects is set, else the formatted error message
    :raises: re-raises the exception from make when suppress_errors is false,
        and always re-raises SystemExit
    """
    if hasattr(self, '_make_tuples'):
        make = self._make_tuples
    else:
        make = self.make

    options = self._make_key_kwargs
    suppress_errors = options['suppress_errors']
    return_exception_objects = options['return_exception_objects']
    reserve_jobs = options['reserve_jobs']
    jobs = options['jobs']

    def job_ref():
        # (table name, job key) pair identifying this key in the jobs table
        return self.target.table_name, self._job_key(key)

    # nothing to do when reservation is requested but could not be obtained
    if reserve_jobs and not jobs.reserve(*job_ref()):
        return None

    self.connection.start_transaction()
    if key in self.target:  # already populated
        self.connection.cancel_transaction()
        if reserve_jobs:
            jobs.complete(*job_ref())
        return None

    logger.info('Populating: ' + str(key))
    self.__class__._allow_insert = True  # reset in the finally clause below
    try:
        make(dict(key))
    except (KeyboardInterrupt, SystemExit, Exception) as error:
        try:
            self.connection.cancel_transaction()
        except OperationalError:
            pass
        # error name followed by the error message (if any)
        detail = str(error)
        error_message = '{exception}{msg}'.format(
            exception=error.__class__.__name__,
            msg=': ' + detail if detail else '')
        if reserve_jobs:
            jobs.error(
                *job_ref(),
                error_message=error_message, error_stack=traceback.format_exc())
        if not suppress_errors or isinstance(error, SystemExit):
            raise
        logger.error(error)
        return key, (error if return_exception_objects else error_message)
    else:
        self.connection.commit_transaction()
        if reserve_jobs:
            jobs.complete(*job_ref())
    finally:
        self.__class__._allow_insert = False
    return None
190
257
191
258
def progress (self , * restrictions , display = True ):
192
259
"""
0 commit comments