9
9
from .errors import DataJointError , LostConnectionError
10
10
from .table import FreeTable
11
11
import signal
12
+ import multiprocessing as mp
12
13
13
14
# noinspection PyExceptionInherit,PyCallingNonCallable
14
15
15
16
logger = logging .getLogger (__name__ )
16
17
17
18
19
def initializer(table):
    """
    Worker start-up hook: remember `table` on this process and reconnect it.

    The table is stored as the `table` attribute of the current process
    object, where call_make_key() looks it up, and its connection is then
    opened again through table.connection.connect().

    :param table: object exposing `connection.connect()`; populate() hands
        this function to the process pool together with the table itself
    """
    worker = mp.current_process()
    worker.table = table
    table.connection.connect()  # reconnect
24
+
25
def call_make_key(key):
    """
    Run make_key(key) on the table attached to the current process.

    The table is read from the `table` attribute of the current process
    object, which initializer() sets.

    :param key: passed unchanged to table.make_key()
    :return: whatever table.make_key(key) returns
    """
    return mp.current_process().table.make_key(key)
30
+
31
+
18
32
class AutoPopulate :
19
33
"""
20
34
AutoPopulate is a mixin class that adds the method populate() to a Relation class.
@@ -102,29 +116,36 @@ def _jobs_to_do(self, restrictions):
102
116
103
117
def populate (self , * restrictions , suppress_errors = False , return_exception_objects = False ,
104
118
reserve_jobs = False , order = "original" , limit = None , max_calls = None ,
105
- display_progress = False ):
119
+ display_progress = False , multiprocess = False ):
106
120
"""
107
121
rel.populate() calls rel.make(key) for every primary key in self.key_source
108
122
for which there is not already a tuple in rel.
109
123
:param restrictions: a list of restrictions each restrict (rel.key_source - target.proj())
110
124
:param suppress_errors: if True, do not terminate execution.
111
125
:param return_exception_objects: return error objects instead of just error messages
112
- :param reserve_jobs: if true, reserves job to populate in asynchronous fashion
126
+ :param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion
113
127
:param order: "original"|"reverse"|"random" - the order of execution
128
+ :param limit: if not None, check at most this many keys
129
+ :param max_calls: if not None, populate at most this many keys
114
130
:param display_progress: if True, report progress_bar
115
- :param limit : if not None, checks at most that many keys
116
- :param max_calls: if not None, populates at max that many keys
131
+ :param multiprocess : if True, use as many processes as CPU cores, or use the integer
132
+ number of processes specified
117
133
"""
118
134
if self .connection .in_transaction :
119
135
raise DataJointError ('Populate cannot be called during a transaction.' )
120
136
121
137
valid_order = ['original' , 'reverse' , 'random' ]
122
138
if order not in valid_order :
123
139
raise DataJointError ('The order argument must be one of %s' % str (valid_order ))
124
- error_list = [] if suppress_errors else None
125
140
jobs = self .connection .schemas [self .target .database ].jobs if reserve_jobs else None
126
141
127
- # define and setup signal handler for SIGTERM
142
+ self ._make_key_kwargs = {'suppress_errors' :suppress_errors ,
143
+ 'return_exception_objects' :return_exception_objects ,
144
+ 'reserve_jobs' :reserve_jobs ,
145
+ 'jobs' :jobs ,
146
+ }
147
+
148
+ # define and set up signal handler for SIGTERM:
128
149
if reserve_jobs :
129
150
def handler (signum , frame ):
130
151
logger .info ('Populate terminated by SIGTERM' )
@@ -137,55 +158,101 @@ def handler(signum, frame):
137
158
elif order == "random" :
138
159
random .shuffle (keys )
139
160
140
- call_count = 0
141
161
logger .info ('Found %d keys to populate' % len (keys ))
142
162
143
- make = self ._make_tuples if hasattr (self , '_make_tuples' ) else self .make
163
+ if max_calls is not None :
164
+ keys = keys [:max_calls ]
165
+ nkeys = len (keys )
144
166
145
- for key in (tqdm (keys ) if display_progress else keys ):
146
- if max_calls is not None and call_count >= max_calls :
147
- break
148
- if not reserve_jobs or jobs .reserve (self .target .table_name , self ._job_key (key )):
149
- self .connection .start_transaction ()
150
- if key in self .target : # already populated
151
- self .connection .cancel_transaction ()
152
- if reserve_jobs :
153
- jobs .complete (self .target .table_name , self ._job_key (key ))
167
+ if multiprocess : # True or int, presumably
168
+ if multiprocess is True :
169
+ nproc = mp .cpu_count ()
170
+ else :
171
+ if not isinstance (multiprocess , int ):
172
+ raise DataJointError ("multiprocess can be False, True or a positive integer" )
173
+ nproc = multiprocess
174
+ else :
175
+ nproc = 1
176
+ nproc = min (nproc , nkeys ) # no sense spawning more than can be used
177
+ error_list = []
178
+ if nproc > 1 : # spawn multiple processes
179
+ # prepare to pickle self:
180
+ self .connection .close () # disconnect parent process from MySQL server
181
+ del self .connection ._conn .ctx # SSLContext is not picklable
182
+ print ('*** Spawning pool of %d processes' % nproc )
183
+ # send pickled copy of self to each process,
184
+ # each worker process calls initializer(*initargs) when it starts
185
+ with mp .Pool (nproc , initializer , (self ,)) as pool :
186
+ if display_progress :
187
+ with tqdm (total = nkeys ) as pbar :
188
+ for error in pool .imap (call_make_key , keys , chunksize = 1 ):
189
+ if error is not None :
190
+ error_list .append (error )
191
+ pbar .update ()
154
192
else :
155
- logger .info ('Populating: ' + str (key ))
156
- call_count += 1
157
- self .__class__ ._allow_insert = True
158
- try :
159
- make (dict (key ))
160
- except (KeyboardInterrupt , SystemExit , Exception ) as error :
161
- try :
162
- self .connection .cancel_transaction ()
163
- except LostConnectionError :
164
- pass
165
- error_message = '{exception}{msg}' .format (
166
- exception = error .__class__ .__name__ ,
167
- msg = ': ' + str (error ) if str (error ) else '' )
168
- if reserve_jobs :
169
- # show error name and error message (if any)
170
- jobs .error (
171
- self .target .table_name , self ._job_key (key ),
172
- error_message = error_message , error_stack = traceback .format_exc ())
173
- if not suppress_errors or isinstance (error , SystemExit ):
174
- raise
175
- else :
176
- logger .error (error )
177
- error_list .append ((key , error if return_exception_objects else error_message ))
178
- else :
179
- self .connection .commit_transaction ()
180
- if reserve_jobs :
181
- jobs .complete (self .target .table_name , self ._job_key (key ))
182
- finally :
183
- self .__class__ ._allow_insert = False
193
+ for error in pool .imap (call_make_key , keys ):
194
+ if error is not None :
195
+ error_list .append (error )
196
+ self .connection .connect () # reconnect parent process to MySQL server
197
+ else : # use single process
198
+ for key in tqdm (keys ) if display_progress else keys :
199
+ error = self .make_key (key )
200
+ if error is not None :
201
+ error_list .append (error )
184
202
185
- # place back the original signal handler
203
+ del self ._make_key_kwargs # clean up
204
+
205
+ # restore original signal handler:
186
206
if reserve_jobs :
187
207
signal .signal (signal .SIGTERM , old_handler )
188
- return error_list
208
+
209
+ if suppress_errors :
210
+ return error_list
211
+
212
def make_key(self, key):
    """
    Populate a single key inside its own transaction.

    Options are read from self._make_key_kwargs, which populate() sets
    before calling this method: 'suppress_errors',
    'return_exception_objects', 'reserve_jobs' and 'jobs'.

    :param key: the key to populate; a dict copy of it is passed to make()
    :return: None when the key was skipped (job not reserved, or key already
        in self.target) or populated successfully; otherwise a tuple
        (key, error), where error is the exception object if
        return_exception_objects is set and the formatted message if not
    :raises: re-raises the exception from make() when suppress_errors is
        false or the exception is a SystemExit
    """
    # prefer the legacy _make_tuples callback when the table defines one
    if hasattr(self, '_make_tuples'):
        make = self._make_tuples
    else:
        make = self.make

    options = self._make_key_kwargs
    suppress_errors = options['suppress_errors']
    return_exception_objects = options['return_exception_objects']
    reserve_jobs = options['reserve_jobs']
    jobs = options['jobs']

    # guard: with job reservation on, skip keys that could not be reserved
    if reserve_jobs and not jobs.reserve(self.target.table_name, self._job_key(key)):
        return None

    self.connection.start_transaction()

    # guard: already populated
    if key in self.target:
        self.connection.cancel_transaction()
        if reserve_jobs:
            jobs.complete(self.target.table_name, self._job_key(key))
        return None

    logger.info('Populating: ' + str(key))
    self.__class__._allow_insert = True
    try:
        make(dict(key))
    except (KeyboardInterrupt, SystemExit, Exception) as error:
        try:
            self.connection.cancel_transaction()
        except LostConnectionError:
            pass
        # error name, followed by the message when there is one
        detail = str(error)
        error_message = '{exception}{msg}'.format(
            exception=error.__class__.__name__,
            msg=': ' + detail if detail else '')
        if reserve_jobs:
            jobs.error(
                self.target.table_name, self._job_key(key),
                error_message=error_message, error_stack=traceback.format_exc())
        if not suppress_errors or isinstance(error, SystemExit):
            raise
        logger.error(error)
        return (key, error if return_exception_objects else error_message)
    else:
        self.connection.commit_transaction()
        if reserve_jobs:
            jobs.complete(self.target.table_name, self._job_key(key))
    finally:
        # always switch the insert permission back off
        self.__class__._allow_insert = False
189
256
190
257
def progress (self , * restrictions , display = True ):
191
258
"""
0 commit comments