8
8
from .expression import QueryExpression , AndList
9
9
from .errors import DataJointError , LostConnectionError
10
10
import signal
11
+ import multiprocessing as mp
11
12
12
13
# noinspection PyExceptionInherit,PyCallingNonCallable
13
14
14
15
logger = logging .getLogger (__name__ )
15
16
16
17
18
# --- helper functions for multiprocessing ---


def _initialize_populate(table, jobs, populate_kwargs):
    """
    Initialize a worker process for multiprocessing.

    Saves the unpickled copy of the table, the jobs table, and the populate
    keyword arguments on the current process object, then reconnects to the
    database server in this worker process.

    :param table: the AutoPopulate table instance (unpickled in this process)
    :param jobs: the jobs table used for job reservation, or None if jobs are not reserved
    :param populate_kwargs: dict of keyword arguments forwarded to table._populate1
    """
    process = mp.current_process()
    process.table = table
    process.jobs = jobs
    process.populate_kwargs = populate_kwargs
    table.connection.connect()  # reconnect in this worker process
31
+
32
def _call_populate1(key):
    """
    Call the current process' table._populate1() for one key.

    Relies on the attributes installed on the process by _initialize_populate.

    :param key: a dict specifying the job to compute
    :return: (key, error) if an error occurred, otherwise None
    """
    process = mp.current_process()
    return process.table._populate1(key, process.jobs, **process.populate_kwargs)
40
+
41
+
17
42
class AutoPopulate :
18
43
"""
19
44
AutoPopulate is a mixin class that adds the method populate() to a Relation class.
@@ -28,8 +53,9 @@ def key_source(self):
28
53
"""
29
54
:return: the query expression that yields primary key values to be passed,
30
55
sequentially, to the ``make`` method when populate() is called.
31
- The default value is the join of the parent relations.
32
- Users may override to change the granularity or the scope of populate() calls.
56
+ The default value is the join of the parent tables referenced from the primary key.
57
+ Subclasses may override the key_source to change the scope or the granularity
58
+ of the make calls.
33
59
"""
34
60
def _rename_attributes (table , props ):
35
61
return (table .proj (
@@ -96,29 +122,30 @@ def _jobs_to_do(self, restrictions):
96
122
97
123
def populate (self , * restrictions , suppress_errors = False , return_exception_objects = False ,
98
124
reserve_jobs = False , order = "original" , limit = None , max_calls = None ,
99
- display_progress = False ):
125
+ display_progress = False , processes = 1 ):
100
126
"""
101
- rel .populate() calls rel .make(key) for every primary key in self.key_source
102
- for which there is not already a tuple in rel .
103
- :param restrictions: a list of restrictions each restrict (rel .key_source - target.proj())
127
+ table .populate() calls table .make(key) for every primary key in self.key_source
128
+ for which there is not already a tuple in table .
129
+ :param restrictions: a list of restrictions each restrict (table .key_source - target.proj())
104
130
:param suppress_errors: if True, do not terminate execution.
105
131
:param return_exception_objects: return error objects instead of just error messages
106
- :param reserve_jobs: if true, reserves job to populate in asynchronous fashion
132
+ :param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion
107
133
:param order: "original"|"reverse"|"random" - the order of execution
134
+ :param limit: if not None, check at most this many keys
135
+ :param max_calls: if not None, populate at most this many keys
108
136
:param display_progress: if True, report progress_bar
109
- :param limit: if not None, checks at most that many keys
110
- :param max_calls: if not None, populates at max that many keys
137
+ :param processes: number of processes to use. When set to a large number, then
138
+ uses as many as CPU cores
111
139
"""
112
140
if self .connection .in_transaction :
113
141
raise DataJointError ('Populate cannot be called during a transaction.' )
114
142
115
143
valid_order = ['original' , 'reverse' , 'random' ]
116
144
if order not in valid_order :
117
145
raise DataJointError ('The order argument must be one of %s' % str (valid_order ))
118
- error_list = [] if suppress_errors else None
119
146
jobs = self .connection .schemas [self .target .database ].jobs if reserve_jobs else None
120
147
121
- # define and setup signal handler for SIGTERM
148
+ # define and set up signal handler for SIGTERM:
122
149
if reserve_jobs :
123
150
def handler (signum , frame ):
124
151
logger .info ('Populate terminated by SIGTERM' )
@@ -131,60 +158,99 @@ def handler(signum, frame):
131
158
elif order == "random" :
132
159
random .shuffle (keys )
133
160
134
- call_count = 0
135
161
logger .info ('Found %d keys to populate' % len (keys ))
136
162
137
- make = self ._make_tuples if hasattr (self , '_make_tuples' ) else self .make
163
+ keys = keys [:max_calls ]
164
+ nkeys = len (keys )
138
165
139
- for key in (tqdm (keys , desc = self .__class__ .__name__ ) if display_progress else keys ):
140
- if max_calls is not None and call_count >= max_calls :
141
- break
142
- if not reserve_jobs or jobs .reserve (self .target .table_name , self ._job_key (key )):
143
- self .connection .start_transaction ()
144
- if key in self .target : # already populated
145
- self .connection .cancel_transaction ()
146
- if reserve_jobs :
147
- jobs .complete (self .target .table_name , self ._job_key (key ))
166
+ if processes > 1 :
167
+ processes = min (processes , nkeys , mp .cpu_count ())
168
+
169
+ error_list = []
170
+ populate_kwargs = dict (
171
+ suppress_errors = suppress_errors ,
172
+ return_exception_objects = return_exception_objects )
173
+
174
+ if processes == 1 :
175
+ for key in tqdm (keys , desc = self .__class__ .__name__ ) if display_progress else keys :
176
+ error = self ._populate1 (key , jobs , ** populate_kwargs )
177
+ if error is not None :
178
+ error_list .append (error )
179
+ else :
180
+ # spawn multiple processes
181
+ self .connection .close () # disconnect parent process from MySQL server
182
+ del self .connection ._conn .ctx # SSLContext is not pickleable
183
+ with mp .Pool (processes , _initialize_populate , (self , populate_kwargs )) as pool :
184
+ if display_progress :
185
+ with tqdm (desc = "Processes: " , total = nkeys ) as pbar :
186
+ for error in pool .imap (_call_populate1 , keys , chunksize = 1 ):
187
+ if error is not None :
188
+ error_list .append (error )
189
+ pbar .update ()
148
190
else :
149
- logger .info ('Populating: ' + str (key ))
150
- call_count += 1
151
- self .__class__ ._allow_insert = True
152
- try :
153
- make (dict (key ))
154
- except (KeyboardInterrupt , SystemExit , Exception ) as error :
155
- try :
156
- self .connection .cancel_transaction ()
157
- except LostConnectionError :
158
- pass
159
- error_message = '{exception}{msg}' .format (
160
- exception = error .__class__ .__name__ ,
161
- msg = ': ' + str (error ) if str (error ) else '' )
162
- if reserve_jobs :
163
- # show error name and error message (if any)
164
- jobs .error (
165
- self .target .table_name , self ._job_key (key ),
166
- error_message = error_message , error_stack = traceback .format_exc ())
167
- if not suppress_errors or isinstance (error , SystemExit ):
168
- raise
169
- else :
170
- logger .error (error )
171
- error_list .append ((key , error if return_exception_objects else error_message ))
172
- else :
173
- self .connection .commit_transaction ()
174
- if reserve_jobs :
175
- jobs .complete (self .target .table_name , self ._job_key (key ))
176
- finally :
177
- self .__class__ ._allow_insert = False
191
+ for error in pool .imap (_call_populate1 , keys ):
192
+ if error is not None :
193
+ error_list .append (error )
194
+ self .connection .connect () # reconnect parent process to MySQL server
178
195
179
- # place back the original signal handler
196
+ # restore original signal handler:
180
197
if reserve_jobs :
181
198
signal .signal (signal .SIGTERM , old_handler )
182
- return error_list
199
+
200
+ if suppress_errors :
201
+ return error_list
202
+
203
    def _populate1(self, key, jobs, suppress_errors, return_exception_objects):
        """
        Populate the table for one source key, calling self.make inside a transaction.

        :param key: dict specifying the job to populate
        :param jobs: the jobs table used for job reservation, or None if jobs are not reserved
        :param suppress_errors: if True, errors are logged and returned instead of raised
            (SystemExit is always re-raised)
        :param return_exception_objects: if True, return the exception object rather than
            its formatted message string
        :return: (key, error) when suppress_errors=True and an error occurred, otherwise None
        """
        # legacy tables may define _make_tuples instead of make
        make = self._make_tuples if hasattr(self, '_make_tuples') else self.make

        # proceed only when no job reservation is used or this job was successfully reserved
        if jobs is None or jobs.reserve(self.target.table_name, self._job_key(key)):
            self.connection.start_transaction()
            if key in self.target:  # already populated
                self.connection.cancel_transaction()
                if jobs is not None:
                    jobs.complete(self.target.table_name, self._job_key(key))
            else:
                logger.info('Populating: ' + str(key))
                self.__class__._allow_insert = True
                try:
                    make(dict(key))
                except (KeyboardInterrupt, SystemExit, Exception) as error:
                    try:
                        self.connection.cancel_transaction()
                    except LostConnectionError:
                        # connection already gone; nothing left to roll back
                        pass
                    error_message = '{exception}{msg}'.format(
                        exception=error.__class__.__name__,
                        msg=': ' + str(error) if str(error) else '')
                    if jobs is not None:
                        # show error name and error message (if any)
                        jobs.error(
                            self.target.table_name, self._job_key(key),
                            error_message=error_message, error_stack=traceback.format_exc())
                    if not suppress_errors or isinstance(error, SystemExit):
                        raise
                    else:
                        logger.error(error)
                        return key, error if return_exception_objects else error_message
                else:
                    self.connection.commit_transaction()
                    if jobs is not None:
                        jobs.complete(self.target.table_name, self._job_key(key))
                finally:
                    # always clear the insert permission flag, even on error
                    self.__class__._allow_insert = False
183
249
184
250
def progress (self , * restrictions , display = True ):
185
251
"""
186
- report progress of populating the table
187
- :return: remaining, total -- tuples to be populated
252
+ Report the progress of populating the table.
253
+ :return: (remaining, total) -- numbers of tuples to be populated
188
254
"""
189
255
todo = self ._jobs_to_do (restrictions )
190
256
total = len (todo )
0 commit comments