88from .expression import QueryExpression , AndList
99from .errors import DataJointError , LostConnectionError
1010import signal
11+ import multiprocessing as mp
1112
1213# noinspection PyExceptionInherit,PyCallingNonCallable
1314
1415logger = logging .getLogger (__name__ )
1516
1617
18+ # --- helper functions for multiprocessing --
19+
def _initialize_populate(table, jobs, populate_kwargs):
    """
    Prepare a worker process for multiprocessing.
    Attaches the unpickled copy of the table, the jobs table, and the populate
    keyword arguments to the current process object, then reconnects the table.
    :param table: the table to store on the current process
    :param jobs: the jobs object to store on the current process
    :param populate_kwargs: the keyword arguments to store on the current process
    """
    worker = mp.current_process()
    worker.table, worker.jobs, worker.populate_kwargs = table, jobs, populate_kwargs
    # the connection is re-opened from within this process
    table.connection.connect()
30+
31+
def _call_populate1(key):
    """
    Run _populate1() of the table stored on the current process for one key.
    The jobs object and keyword arguments stored on the current process are
    passed along with the key.
    :param key: a dict specifying the job to compute
    :return: the value returned by the table's _populate1()
    """
    worker = mp.current_process()
    populate1 = worker.table._populate1
    return populate1(key, worker.jobs, **worker.populate_kwargs)
40+
41+
1742class AutoPopulate :
1843 """
1944 AutoPopulate is a mixin class that adds the method populate() to a Relation class.
@@ -26,18 +51,20 @@ class AutoPopulate:
2651 @property
2752 def key_source (self ):
2853 """
29- :return: the relation whose primary key values are passed, sequentially, to the
30- ``make`` method when populate() is called.
31- The default value is the join of the parent relations.
32- Users may override to change the granularity or the scope of populate() calls.
54+ :return: the query expression that yields primary key values to be passed,
55+ sequentially, to the ``make`` method when populate() is called.
56+ The default value is the join of the parent tables referenced from the primary key.
57+ Subclasses may override the key_source to change the scope or the granularity
58+ of the make calls.
3359 """
3460 def _rename_attributes (table , props ):
3561 return (table .proj (
3662 ** {attr : ref for attr , ref in props ['attr_map' ].items () if attr != ref })
37- if props ['aliased' ] else table )
63+ if props ['aliased' ] else table . proj () )
3864
3965 if self ._key_source is None :
40- parents = self .target .parents (primary = True , as_objects = True , foreign_key_info = True )
66+ parents = self .target .parents (
67+ primary = True , as_objects = True , foreign_key_info = True )
4168 if not parents :
4269 raise DataJointError ('A table must have dependencies '
4370 'from its primary key for auto-populate to work' )
@@ -48,17 +75,19 @@ def _rename_attributes(table, props):
4875
4976 def make (self , key ):
5077 """
51- Derived classes must implement method `make` that fetches data from tables that are
52- above them in the dependency hierarchy, restricting by the given key, computes dependent
53- attributes, and inserts the new tuples into self.
78+ Derived classes must implement method `make` that fetches data from tables
79+ above them in the dependency hierarchy, restricting by the given key,
80+ computes secondary attributes, and inserts the new tuples into self.
5481 """
55- raise NotImplementedError ('Subclasses of AutoPopulate must implement the method `make`' )
82+ raise NotImplementedError (
83+ 'Subclasses of AutoPopulate must implement the method `make`' )
5684
5785 @property
5886 def target (self ):
5987 """
6088 :return: table to be populated.
61- In the typical case, dj.AutoPopulate is mixed into a dj.Table class by inheritance and the target is self.
89+ In the typical case, dj.AutoPopulate is mixed into a dj.Table class by
90+ inheritance and the target is self.
6291 """
6392 return self
6493
@@ -85,41 +114,49 @@ def _jobs_to_do(self, restrictions):
85114
86115 if not isinstance (todo , QueryExpression ):
87116 raise DataJointError ('Invalid key_source value' )
88- # check if target lacks any attributes from the primary key of key_source
117+
89118 try :
119+ # check if target lacks any attributes from the primary key of key_source
90120 raise DataJointError (
91- 'The populate target lacks attribute %s from the primary key of key_source' % next (
92- name for name in todo .heading .primary_key if name not in self .target .heading ))
121+ 'The populate target lacks attribute %s '
122+ 'from the primary key of key_source' % next (
123+ name for name in todo .heading .primary_key
124+ if name not in self .target .heading ))
93125 except StopIteration :
94126 pass
95127 return (todo & AndList (restrictions )).proj ()
96128
97129 def populate (self , * restrictions , suppress_errors = False , return_exception_objects = False ,
98130 reserve_jobs = False , order = "original" , limit = None , max_calls = None ,
99- display_progress = False , make_kwargs = None ):
131+ display_progress = False , processes = 1 , make_kwargs = None ):
100132 """
101- rel.populate() calls rel.make(key) for every primary key in self.key_source
102- for which there is not already a tuple in rel.
103- :param restrictions: a list of restrictions each restrict (rel.key_source - target.proj())
133+ table.populate() calls table.make(key) for every primary key in self.key_source
134+ for which there is not already a tuple in table.
135+ :param restrictions: a list of restrictions each restrict
136+ (table.key_source - target.proj())
104137 :param suppress_errors: if True, do not terminate execution.
105138 :param return_exception_objects: return error objects instead of just error messages
106- :param reserve_jobs: if true, reserves job to populate in asynchronous fashion
139+ :param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion
107140 :param order: "original"|"reverse"|"random" - the order of execution
141+ :param limit: if not None, check at most this many keys
142+ :param max_calls: if not None, populate at most this many keys
108143 :param display_progress: if True, report progress_bar
111- :param make_kwargs: optional dict containing keyword arguments that will be passed down to each make() call
146+ :param processes: number of processes to use. The number is capped at the
147+ number of keys to populate and the number of CPU cores.
148+ :param make_kwargs: optional dict containing keyword arguments that will be
149+ passed down to each make() call
112150 """
113151 if self .connection .in_transaction :
114152 raise DataJointError ('Populate cannot be called during a transaction.' )
115153
116154 valid_order = ['original' , 'reverse' , 'random' ]
117155 if order not in valid_order :
118156 raise DataJointError ('The order argument must be one of %s' % str (valid_order ))
119- error_list = [] if suppress_errors else None
120157 jobs = self .connection .schemas [self .target .database ].jobs if reserve_jobs else None
121158
122- # define and setup signal handler for SIGTERM
159+ # define and set up signal handler for SIGTERM:
123160 if reserve_jobs :
124161 def handler (signum , frame ):
125162 logger .info ('Populate terminated by SIGTERM' )
@@ -132,60 +169,99 @@ def handler(signum, frame):
132169 elif order == "random" :
133170 random .shuffle (keys )
134171
135- call_count = 0
136172 logger .info ('Found %d keys to populate' % len (keys ))
137173
138- make = self ._make_tuples if hasattr (self , '_make_tuples' ) else self .make
174+ keys = keys [:max_calls ]
175+ nkeys = len (keys )
139176
140- for key in (tqdm (keys , desc = self .__class__ .__name__ ) if display_progress else keys ):
141- if max_calls is not None and call_count >= max_calls :
142- break
143- if not reserve_jobs or jobs .reserve (self .target .table_name , self ._job_key (key )):
144- self .connection .start_transaction ()
145- if key in self .target : # already populated
146- self .connection .cancel_transaction ()
147- if reserve_jobs :
148- jobs .complete (self .target .table_name , self ._job_key (key ))
177+ if processes > 1 :
178+ processes = min (processes , nkeys , mp .cpu_count ())
179+
180+ error_list = []
181+ populate_kwargs = dict (
182+ suppress_errors = suppress_errors ,
183+ return_exception_objects = return_exception_objects )
184+
185+ if processes == 1 :
186+ for key in tqdm (keys , desc = self .__class__ .__name__ ) if display_progress else keys :
187+ error = self ._populate1 (key , jobs , ** populate_kwargs )
188+ if error is not None :
189+ error_list .append (error )
190+ else :
191+ # spawn multiple processes
192+ self .connection .close () # disconnect parent process from MySQL server
193+ del self .connection ._conn .ctx # SSLContext is not pickleable
194+ with mp .Pool (processes , _initialize_populate , (self , populate_kwargs )) as pool :
195+ if display_progress :
196+ with tqdm (desc = "Processes: " , total = nkeys ) as pbar :
197+ for error in pool .imap (_call_populate1 , keys , chunksize = 1 ):
198+ if error is not None :
199+ error_list .append (error )
200+ pbar .update ()
149201 else :
150- logger .info ('Populating: ' + str (key ))
151- call_count += 1
152- self .__class__ ._allow_insert = True
153- try :
154- make (dict (key ), ** (make_kwargs or {}))
155- except (KeyboardInterrupt , SystemExit , Exception ) as error :
156- try :
157- self .connection .cancel_transaction ()
158- except LostConnectionError :
159- pass
160- error_message = '{exception}{msg}' .format (
161- exception = error .__class__ .__name__ ,
162- msg = ': ' + str (error ) if str (error ) else '' )
163- if reserve_jobs :
164- # show error name and error message (if any)
165- jobs .error (
166- self .target .table_name , self ._job_key (key ),
167- error_message = error_message , error_stack = traceback .format_exc ())
168- if not suppress_errors or isinstance (error , SystemExit ):
169- raise
170- else :
171- logger .error (error )
172- error_list .append ((key , error if return_exception_objects else error_message ))
173- else :
174- self .connection .commit_transaction ()
175- if reserve_jobs :
176- jobs .complete (self .target .table_name , self ._job_key (key ))
177- finally :
178- self .__class__ ._allow_insert = False
202+ for error in pool .imap (_call_populate1 , keys ):
203+ if error is not None :
204+ error_list .append (error )
205+ self .connection .connect () # reconnect parent process to MySQL server
179206
180- # place back the original signal handler
207+ # restore original signal handler:
181208 if reserve_jobs :
182209 signal .signal (signal .SIGTERM , old_handler )
183- return error_list
210+
211+ if suppress_errors :
212+ return error_list
213+
214+ def _populate1 (self , key , jobs , suppress_errors , return_exception_objects ):
215+ """
216+ populates table for one source key, calling self.make inside a transaction.
217+ :param jobs: the jobs table or None if not reserve_jobs
218+ :param key: dict specifying job to populate
219+ :param suppress_errors: bool if errors should be suppressed and returned
220+ :param return_exception_objects: if True, errors must be returned as objects
221+ :return: (key, error) when suppress_errors=True, otherwise None
222+ """
223+ make = self ._make_tuples if hasattr (self , '_make_tuples' ) else self .make
224+
225+ if jobs is None or jobs .reserve (self .target .table_name , self ._job_key (key )):
226+ self .connection .start_transaction ()
227+ if key in self .target : # already populated
228+ self .connection .cancel_transaction ()
229+ if jobs is not None :
230+ jobs .complete (self .target .table_name , self ._job_key (key ))
231+ else :
232+ logger .info ('Populating: ' + str (key ))
233+ self .__class__ ._allow_insert = True
234+ try :
235+ make (dict (key ))
236+ except (KeyboardInterrupt , SystemExit , Exception ) as error :
237+ try :
238+ self .connection .cancel_transaction ()
239+ except LostConnectionError :
240+ pass
241+ error_message = '{exception}{msg}' .format (
242+ exception = error .__class__ .__name__ ,
243+ msg = ': ' + str (error ) if str (error ) else '' )
244+ if jobs is not None :
245+ # show error name and error message (if any)
246+ jobs .error (
247+ self .target .table_name , self ._job_key (key ),
248+ error_message = error_message , error_stack = traceback .format_exc ())
249+ if not suppress_errors or isinstance (error , SystemExit ):
250+ raise
251+ else :
252+ logger .error (error )
253+ return key , error if return_exception_objects else error_message
254+ else :
255+ self .connection .commit_transaction ()
256+ if jobs is not None :
257+ jobs .complete (self .target .table_name , self ._job_key (key ))
258+ finally :
259+ self .__class__ ._allow_insert = False
184260
185261 def progress (self , * restrictions , display = True ):
186262 """
187- report progress of populating the table
188- :return: remaining, total -- tuples to be populated
263+ Report the progress of populating the table.
264+ :return: ( remaining, total) -- numbers of tuples to be populated
189265 """
190266 todo = self ._jobs_to_do (restrictions )
191267 total = len (todo )
@@ -194,5 +270,6 @@ def progress(self, *restrictions, display=True):
194270 print ('%-20s' % self .__class__ .__name__ ,
195271 'Completed %d of %d (%2.1f%%) %s' % (
196272 total - remaining , total , 100 - 100 * remaining / (total + 1e-12 ),
197- datetime .datetime .strftime (datetime .datetime .now (), '%Y-%m-%d %H:%M:%S' )), flush = True )
273+ datetime .datetime .strftime (datetime .datetime .now (),
274+ '%Y-%m-%d %H:%M:%S' )), flush = True )
198275 return remaining , total
0 commit comments