8
8
from .expression import QueryExpression , AndList
9
9
from .errors import DataJointError , LostConnectionError
10
10
import signal
11
+ import multiprocessing as mp
11
12
12
13
# noinspection PyExceptionInherit,PyCallingNonCallable
13
14
14
15
logger = logging .getLogger (__name__ )
15
16
16
17
18
+ # --- helper functions for multiprocessing --
19
+
20
+ def _initialize_populate (table , jobs , populate_kwargs ):
21
+ """
22
+ Initialize the process for mulitprocessing.
23
+ Saves the unpickled copy of the table to the current process and reconnects.
24
+ """
25
+ process = mp .current_process ()
26
+ process .table = table
27
+ process .jobs = jobs
28
+ process .populate_kwargs = populate_kwargs
29
+ table .connection .connect () # reconnect
30
+
31
+
32
+ def _call_populate1 (key ):
33
+ """
34
+ Call current process' table._populate1()
35
+ :key - a dict specifying job to compute
36
+ :return: key, error if error, otherwise None
37
+ """
38
+ process = mp .current_process ()
39
+ return process .table ._populate1 (key , process .jobs , ** process .populate_kwargs )
40
+
41
+
17
42
class AutoPopulate :
18
43
"""
19
44
AutoPopulate is a mixin class that adds the method populate() to a Relation class.
@@ -26,18 +51,20 @@ class AutoPopulate:
26
51
@property
27
52
def key_source (self ):
28
53
"""
29
- :return: the relation whose primary key values are passed, sequentially, to the
30
- ``make`` method when populate() is called.
31
- The default value is the join of the parent relations.
32
- Users may override to change the granularity or the scope of populate() calls.
54
+ :return: the query expression that yields primary key values to be passed,
55
+ sequentially, to the ``make`` method when populate() is called.
56
+ The default value is the join of the parent tables references from the primary key.
57
+ Subclasses may override they key_source to change the scope or the granularity
58
+ of the make calls.
33
59
"""
34
60
def _rename_attributes (table , props ):
35
61
return (table .proj (
36
62
** {attr : ref for attr , ref in props ['attr_map' ].items () if attr != ref })
37
- if props ['aliased' ] else table )
63
+ if props ['aliased' ] else table . proj () )
38
64
39
65
if self ._key_source is None :
40
- parents = self .target .parents (primary = True , as_objects = True , foreign_key_info = True )
66
+ parents = self .target .parents (
67
+ primary = True , as_objects = True , foreign_key_info = True )
41
68
if not parents :
42
69
raise DataJointError ('A table must have dependencies '
43
70
'from its primary key for auto-populate to work' )
@@ -48,17 +75,19 @@ def _rename_attributes(table, props):
48
75
49
76
def make (self , key ):
50
77
"""
51
- Derived classes must implement method `make` that fetches data from tables that are
52
- above them in the dependency hierarchy, restricting by the given key, computes dependent
53
- attributes, and inserts the new tuples into self.
78
+ Derived classes must implement method `make` that fetches data from tables
79
+ above them in the dependency hierarchy, restricting by the given key,
80
+ computes secondary attributes, and inserts the new tuples into self.
54
81
"""
55
- raise NotImplementedError ('Subclasses of AutoPopulate must implement the method `make`' )
82
+ raise NotImplementedError (
83
+ 'Subclasses of AutoPopulate must implement the method `make`' )
56
84
57
85
@property
58
86
def target (self ):
59
87
"""
60
88
:return: table to be populated.
61
- In the typical case, dj.AutoPopulate is mixed into a dj.Table class by inheritance and the target is self.
89
+ In the typical case, dj.AutoPopulate is mixed into a dj.Table class by
90
+ inheritance and the target is self.
62
91
"""
63
92
return self
64
93
@@ -85,41 +114,49 @@ def _jobs_to_do(self, restrictions):
85
114
86
115
if not isinstance (todo , QueryExpression ):
87
116
raise DataJointError ('Invalid key_source value' )
88
- # check if target lacks any attributes from the primary key of key_source
117
+
89
118
try :
119
+ # check if target lacks any attributes from the primary key of key_source
90
120
raise DataJointError (
91
- 'The populate target lacks attribute %s from the primary key of key_source' % next (
92
- name for name in todo .heading .primary_key if name not in self .target .heading ))
121
+ 'The populate target lacks attribute %s '
122
+ 'from the primary key of key_source' % next (
123
+ name for name in todo .heading .primary_key
124
+ if name not in self .target .heading ))
93
125
except StopIteration :
94
126
pass
95
127
return (todo & AndList (restrictions )).proj ()
96
128
97
129
def populate (self , * restrictions , suppress_errors = False , return_exception_objects = False ,
98
130
reserve_jobs = False , order = "original" , limit = None , max_calls = None ,
99
- display_progress = False , make_kwargs = None ):
131
+ display_progress = False , processes = 1 , make_kwargs = None ):
100
132
"""
101
- rel.populate() calls rel.make(key) for every primary key in self.key_source
102
- for which there is not already a tuple in rel.
103
- :param restrictions: a list of restrictions each restrict (rel.key_source - target.proj())
133
+ table.populate() calls table.make(key) for every primary key in self.key_source
134
+ for which there is not already a tuple in table.
135
+ :param restrictions: a list of restrictions each restrict
136
+ (table.key_source - target.proj())
104
137
:param suppress_errors: if True, do not terminate execution.
105
138
:param return_exception_objects: return error objects instead of just error messages
106
- :param reserve_jobs: if true, reserves job to populate in asynchronous fashion
139
+ :param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion
107
140
:param order: "original"|"reverse"|"random" - the order of execution
141
+ :param limit: if not None, check at most this many keys
142
+ :param max_calls: if not None, populate at most this many keys
108
143
:param display_progress: if True, report progress_bar
109
144
:param limit: if not None, checks at most that many keys
110
145
:param max_calls: if not None, populates at max that many keys
111
- :param make_kwargs: optional dict containing keyword arguments that will be passed down to each make() call
146
+ :param processes: number of processes to use. When set to a large number, then
147
+ uses as many as CPU cores
148
+ :param make_kwargs: optional dict containing keyword arguments that will be
149
+ passed down to each make() call
112
150
"""
113
151
if self .connection .in_transaction :
114
152
raise DataJointError ('Populate cannot be called during a transaction.' )
115
153
116
154
valid_order = ['original' , 'reverse' , 'random' ]
117
155
if order not in valid_order :
118
156
raise DataJointError ('The order argument must be one of %s' % str (valid_order ))
119
- error_list = [] if suppress_errors else None
120
157
jobs = self .connection .schemas [self .target .database ].jobs if reserve_jobs else None
121
158
122
- # define and setup signal handler for SIGTERM
159
+ # define and set up signal handler for SIGTERM:
123
160
if reserve_jobs :
124
161
def handler (signum , frame ):
125
162
logger .info ('Populate terminated by SIGTERM' )
@@ -132,60 +169,99 @@ def handler(signum, frame):
132
169
elif order == "random" :
133
170
random .shuffle (keys )
134
171
135
- call_count = 0
136
172
logger .info ('Found %d keys to populate' % len (keys ))
137
173
138
- make = self ._make_tuples if hasattr (self , '_make_tuples' ) else self .make
174
+ keys = keys [:max_calls ]
175
+ nkeys = len (keys )
139
176
140
- for key in (tqdm (keys , desc = self .__class__ .__name__ ) if display_progress else keys ):
141
- if max_calls is not None and call_count >= max_calls :
142
- break
143
- if not reserve_jobs or jobs .reserve (self .target .table_name , self ._job_key (key )):
144
- self .connection .start_transaction ()
145
- if key in self .target : # already populated
146
- self .connection .cancel_transaction ()
147
- if reserve_jobs :
148
- jobs .complete (self .target .table_name , self ._job_key (key ))
177
+ if processes > 1 :
178
+ processes = min (processes , nkeys , mp .cpu_count ())
179
+
180
+ error_list = []
181
+ populate_kwargs = dict (
182
+ suppress_errors = suppress_errors ,
183
+ return_exception_objects = return_exception_objects )
184
+
185
+ if processes == 1 :
186
+ for key in tqdm (keys , desc = self .__class__ .__name__ ) if display_progress else keys :
187
+ error = self ._populate1 (key , jobs , ** populate_kwargs )
188
+ if error is not None :
189
+ error_list .append (error )
190
+ else :
191
+ # spawn multiple processes
192
+ self .connection .close () # disconnect parent process from MySQL server
193
+ del self .connection ._conn .ctx # SSLContext is not pickleable
194
+ with mp .Pool (processes , _initialize_populate , (self , populate_kwargs )) as pool :
195
+ if display_progress :
196
+ with tqdm (desc = "Processes: " , total = nkeys ) as pbar :
197
+ for error in pool .imap (_call_populate1 , keys , chunksize = 1 ):
198
+ if error is not None :
199
+ error_list .append (error )
200
+ pbar .update ()
149
201
else :
150
- logger .info ('Populating: ' + str (key ))
151
- call_count += 1
152
- self .__class__ ._allow_insert = True
153
- try :
154
- make (dict (key ), ** (make_kwargs or {}))
155
- except (KeyboardInterrupt , SystemExit , Exception ) as error :
156
- try :
157
- self .connection .cancel_transaction ()
158
- except LostConnectionError :
159
- pass
160
- error_message = '{exception}{msg}' .format (
161
- exception = error .__class__ .__name__ ,
162
- msg = ': ' + str (error ) if str (error ) else '' )
163
- if reserve_jobs :
164
- # show error name and error message (if any)
165
- jobs .error (
166
- self .target .table_name , self ._job_key (key ),
167
- error_message = error_message , error_stack = traceback .format_exc ())
168
- if not suppress_errors or isinstance (error , SystemExit ):
169
- raise
170
- else :
171
- logger .error (error )
172
- error_list .append ((key , error if return_exception_objects else error_message ))
173
- else :
174
- self .connection .commit_transaction ()
175
- if reserve_jobs :
176
- jobs .complete (self .target .table_name , self ._job_key (key ))
177
- finally :
178
- self .__class__ ._allow_insert = False
202
+ for error in pool .imap (_call_populate1 , keys ):
203
+ if error is not None :
204
+ error_list .append (error )
205
+ self .connection .connect () # reconnect parent process to MySQL server
179
206
180
- # place back the original signal handler
207
+ # restore original signal handler:
181
208
if reserve_jobs :
182
209
signal .signal (signal .SIGTERM , old_handler )
183
- return error_list
210
+
211
+ if suppress_errors :
212
+ return error_list
213
+
214
+ def _populate1 (self , key , jobs , suppress_errors , return_exception_objects ):
215
+ """
216
+ populates table for one source key, calling self.make inside a transaction.
217
+ :param jobs: the jobs table or None if not reserve_jobs
218
+ :param key: dict specifying job to populate
219
+ :param suppress_errors: bool if errors should be suppressed and returned
220
+ :param return_exception_objects: if True, errors must be returned as objects
221
+ :return: (key, error) when suppress_errors=True, otherwise None
222
+ """
223
+ make = self ._make_tuples if hasattr (self , '_make_tuples' ) else self .make
224
+
225
+ if jobs is None or jobs .reserve (self .target .table_name , self ._job_key (key )):
226
+ self .connection .start_transaction ()
227
+ if key in self .target : # already populated
228
+ self .connection .cancel_transaction ()
229
+ if jobs is not None :
230
+ jobs .complete (self .target .table_name , self ._job_key (key ))
231
+ else :
232
+ logger .info ('Populating: ' + str (key ))
233
+ self .__class__ ._allow_insert = True
234
+ try :
235
+ make (dict (key ))
236
+ except (KeyboardInterrupt , SystemExit , Exception ) as error :
237
+ try :
238
+ self .connection .cancel_transaction ()
239
+ except LostConnectionError :
240
+ pass
241
+ error_message = '{exception}{msg}' .format (
242
+ exception = error .__class__ .__name__ ,
243
+ msg = ': ' + str (error ) if str (error ) else '' )
244
+ if jobs is not None :
245
+ # show error name and error message (if any)
246
+ jobs .error (
247
+ self .target .table_name , self ._job_key (key ),
248
+ error_message = error_message , error_stack = traceback .format_exc ())
249
+ if not suppress_errors or isinstance (error , SystemExit ):
250
+ raise
251
+ else :
252
+ logger .error (error )
253
+ return key , error if return_exception_objects else error_message
254
+ else :
255
+ self .connection .commit_transaction ()
256
+ if jobs is not None :
257
+ jobs .complete (self .target .table_name , self ._job_key (key ))
258
+ finally :
259
+ self .__class__ ._allow_insert = False
184
260
185
261
def progress (self , * restrictions , display = True ):
186
262
"""
187
- report progress of populating the table
188
- :return: remaining, total -- tuples to be populated
263
+ Report the progress of populating the table.
264
+ :return: ( remaining, total) -- numbers of tuples to be populated
189
265
"""
190
266
todo = self ._jobs_to_do (restrictions )
191
267
total = len (todo )
@@ -194,5 +270,6 @@ def progress(self, *restrictions, display=True):
194
270
print ('%-20s' % self .__class__ .__name__ ,
195
271
'Completed %d of %d (%2.1f%%) %s' % (
196
272
total - remaining , total , 100 - 100 * remaining / (total + 1e-12 ),
197
- datetime .datetime .strftime (datetime .datetime .now (), '%Y-%m-%d %H:%M:%S' )), flush = True )
273
+ datetime .datetime .strftime (datetime .datetime .now (),
274
+ '%Y-%m-%d %H:%M:%S' )), flush = True )
198
275
return remaining , total
0 commit comments