Skip to content

Commit 40576ce

Browse files
authored
Merge pull request #2 from A-Baji/populate-kwargs
Moved implementation to `_populate1`
2 parents 46b2705 + f6da47b commit 40576ce

19 files changed

+451
-185
lines changed

CHANGELOG.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@
55
* Bugfix - Fix Python 3.10 compatibility (#983) PR #972
66
* Bugfix - Allow renaming non-conforming attributes in proj (#982) PR #972
77
* Add - Expose proxy feature for S3 external stores (#961) PR #962
8+
* Add - implement multiprocessing in populate (#695) PR #704, #969
89
* Bugfix - Dependencies not properly loaded on populate. (#902) PR #919
910
* Bugfix - Replace use of numpy aliases of built-in types with built-in type. (#938) PR #939
11+
* Bugfix - Deletes and drops must include the master of each part. (#151, #374) PR #957
1012
* Bugfix - `ExternalTable.delete` should not remove row on error (#953) PR #956
1113
* Bugfix - Fix error handling of remove_object function in `s3.py` (#952) PR #955
1214
* Bugfix - Fix regression issue with `DISTINCT` clause and `GROUP_BY` (#914) PR #963
1315
* Bugfix - Fix sql code generation to comply with sql mode `ONLY_FULL_GROUP_BY` (#916) PR #965
1416
* Bugfix - Fix count for left-joined `QueryExpressions` (#951) PR #966
1517
* Bugfix - Fix assertion error when performing a union into a join (#930) PR #967
18+
* Update `~jobs.error_stack` from blob to mediumblob to allow error stacks >64kB in jobs (#984) PR #986
1619

1720
### 0.13.2 -- May 7, 2021
1821
* Update `setuptools_certificate` dependency to new name `otumat`
@@ -133,7 +136,7 @@
133136
* Fix #628 - incompatibility with pyparsing 2.4.1
134137

135138
### 0.11.1 -- Nov 15, 2018
136-
* Fix ordering of attributes in proj (#483 and #516)
139+
* Fix ordering of attributes in proj (#483, #516)
137140
* Prohibit direct insert into auto-populated tables (#511)
138141

139142
### 0.11.0 -- Oct 25, 2018
@@ -246,7 +249,7 @@ Documentation and tutorials available at https://docs.datajoint.io and https://t
246249
* ERD() no longer takes the context argument.
247250
* ERD.draw() now takes an optional context argument. By default uses the caller's locals.
248251

249-
### 0.3.2.
252+
### 0.3.2.
250253
* Fixed issue #223: `insert` can insert relations without fetching.
251254
* ERD() now takes the `context` argument, which specifies in which context to look for classes. The default is taken from the argument (schema or relation).
252255
* ERD.draw() no longer has the `prefix` argument: class names are shown as found in the context.

datajoint/autopopulate.py

Lines changed: 148 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,37 @@
88
from .expression import QueryExpression, AndList
99
from .errors import DataJointError, LostConnectionError
1010
import signal
11+
import multiprocessing as mp
1112

1213
# noinspection PyExceptionInherit,PyCallingNonCallable
1314

1415
logger = logging.getLogger(__name__)
1516

1617

18+
# --- helper functions for multiprocessing --
19+
20+
def _initialize_populate(table, jobs, populate_kwargs):
21+
"""
22+
Initialize the process for multiprocessing.
23+
Saves the unpickled copy of the table to the current process and reconnects.
24+
"""
25+
process = mp.current_process()
26+
process.table = table
27+
process.jobs = jobs
28+
process.populate_kwargs = populate_kwargs
29+
table.connection.connect() # reconnect
30+
31+
32+
def _call_populate1(key):
33+
"""
34+
Call current process' table._populate1()
35+
:param key: a dict specifying the job to compute
36+
:return: key, error if error, otherwise None
37+
"""
38+
process = mp.current_process()
39+
return process.table._populate1(key, process.jobs, **process.populate_kwargs)
40+
41+
1742
class AutoPopulate:
1843
"""
1944
AutoPopulate is a mixin class that adds the method populate() to a Relation class.
@@ -26,18 +51,20 @@ class AutoPopulate:
2651
@property
2752
def key_source(self):
2853
"""
29-
:return: the relation whose primary key values are passed, sequentially, to the
30-
``make`` method when populate() is called.
31-
The default value is the join of the parent relations.
32-
Users may override to change the granularity or the scope of populate() calls.
54+
:return: the query expression that yields primary key values to be passed,
55+
sequentially, to the ``make`` method when populate() is called.
56+
The default value is the join of the parent tables referenced from the primary key.
57+
Subclasses may override the key_source to change the scope or the granularity
58+
of the make calls.
3359
"""
3460
def _rename_attributes(table, props):
3561
return (table.proj(
3662
**{attr: ref for attr, ref in props['attr_map'].items() if attr != ref})
37-
if props['aliased'] else table)
63+
if props['aliased'] else table.proj())
3864

3965
if self._key_source is None:
40-
parents = self.target.parents(primary=True, as_objects=True, foreign_key_info=True)
66+
parents = self.target.parents(
67+
primary=True, as_objects=True, foreign_key_info=True)
4168
if not parents:
4269
raise DataJointError('A table must have dependencies '
4370
'from its primary key for auto-populate to work')
@@ -48,17 +75,19 @@ def _rename_attributes(table, props):
4875

4976
def make(self, key):
5077
"""
51-
Derived classes must implement method `make` that fetches data from tables that are
52-
above them in the dependency hierarchy, restricting by the given key, computes dependent
53-
attributes, and inserts the new tuples into self.
78+
Derived classes must implement method `make` that fetches data from tables
79+
above them in the dependency hierarchy, restricting by the given key,
80+
computes secondary attributes, and inserts the new tuples into self.
5481
"""
55-
raise NotImplementedError('Subclasses of AutoPopulate must implement the method `make`')
82+
raise NotImplementedError(
83+
'Subclasses of AutoPopulate must implement the method `make`')
5684

5785
@property
5886
def target(self):
5987
"""
6088
:return: table to be populated.
61-
In the typical case, dj.AutoPopulate is mixed into a dj.Table class by inheritance and the target is self.
89+
In the typical case, dj.AutoPopulate is mixed into a dj.Table class by
90+
inheritance and the target is self.
6291
"""
6392
return self
6493

@@ -85,41 +114,50 @@ def _jobs_to_do(self, restrictions):
85114

86115
if not isinstance(todo, QueryExpression):
87116
raise DataJointError('Invalid key_source value')
88-
# check if target lacks any attributes from the primary key of key_source
117+
89118
try:
119+
# check if target lacks any attributes from the primary key of key_source
90120
raise DataJointError(
91-
'The populate target lacks attribute %s from the primary key of key_source' % next(
92-
name for name in todo.heading.primary_key if name not in self.target.heading))
121+
'The populate target lacks attribute %s '
122+
'from the primary key of key_source' % next(
123+
name for name in todo.heading.primary_key
124+
if name not in self.target.heading))
93125
except StopIteration:
94126
pass
95127
return (todo & AndList(restrictions)).proj()
96128

97129
def populate(self, *restrictions, suppress_errors=False, return_exception_objects=False,
98130
reserve_jobs=False, order="original", limit=None, max_calls=None,
99-
display_progress=False, make_kwargs=None):
131+
display_progress=False, processes=1, make_kwargs=None):
100132
"""
101-
rel.populate() calls rel.make(key) for every primary key in self.key_source
102-
for which there is not already a tuple in rel.
103-
:param restrictions: a list of restrictions each restrict (rel.key_source - target.proj())
133+
``table.populate()`` calls ``table.make(key)`` for every primary key in
134+
``self.key_source`` for which there is not already a tuple in table.
135+
136+
:param restrictions: a list of restrictions each restrict
137+
(table.key_source - target.proj())
104138
:param suppress_errors: if True, do not terminate execution.
105139
:param return_exception_objects: return error objects instead of just error messages
106-
:param reserve_jobs: if true, reserves job to populate in asynchronous fashion
140+
:param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion
107141
:param order: "original"|"reverse"|"random" - the order of execution
142+
:param limit: if not None, check at most this many keys
143+
:param max_calls: if not None, populate at most this many keys
108144
:param display_progress: if True, report progress_bar
109-
:param limit: if not None, checks at most that many keys
110-
:param max_calls: if not None, populates at max that many keys
111-
:param make_kwargs: optional dict containing keyword arguments that will be passed down to each make() call
145+
:param processes: number of processes to use. When set to a large number, then
146+
uses as many processes as there are CPU cores
147+
:param make_kwargs: Keyword arguments which do not affect the result of computation
148+
to be passed down to each ``make()`` call. Computation arguments should be
149+
specified within the pipeline e.g. using a `dj.Lookup` table.
150+
:type make_kwargs: dict, optional
112151
"""
113152
if self.connection.in_transaction:
114153
raise DataJointError('Populate cannot be called during a transaction.')
115154

116155
valid_order = ['original', 'reverse', 'random']
117156
if order not in valid_order:
118157
raise DataJointError('The order argument must be one of %s' % str(valid_order))
119-
error_list = [] if suppress_errors else None
120158
jobs = self.connection.schemas[self.target.database].jobs if reserve_jobs else None
121159

122-
# define and setup signal handler for SIGTERM
160+
# define and set up signal handler for SIGTERM:
123161
if reserve_jobs:
124162
def handler(signum, frame):
125163
logger.info('Populate terminated by SIGTERM')
@@ -132,60 +170,100 @@ def handler(signum, frame):
132170
elif order == "random":
133171
random.shuffle(keys)
134172

135-
call_count = 0
136173
logger.info('Found %d keys to populate' % len(keys))
137174

138-
make = self._make_tuples if hasattr(self, '_make_tuples') else self.make
175+
keys = keys[:max_calls]
176+
nkeys = len(keys)
139177

140-
for key in (tqdm(keys, desc=self.__class__.__name__) if display_progress else keys):
141-
if max_calls is not None and call_count >= max_calls:
142-
break
143-
if not reserve_jobs or jobs.reserve(self.target.table_name, self._job_key(key)):
144-
self.connection.start_transaction()
145-
if key in self.target: # already populated
146-
self.connection.cancel_transaction()
147-
if reserve_jobs:
148-
jobs.complete(self.target.table_name, self._job_key(key))
178+
if processes > 1:
179+
processes = min(processes, nkeys, mp.cpu_count())
180+
181+
error_list = []
182+
populate_kwargs = dict(
183+
suppress_errors=suppress_errors,
184+
return_exception_objects=return_exception_objects,
185+
make_kwargs=make_kwargs)
186+
187+
if processes == 1:
188+
for key in tqdm(keys, desc=self.__class__.__name__) if display_progress else keys:
189+
error = self._populate1(key, jobs, **populate_kwargs)
190+
if error is not None:
191+
error_list.append(error)
192+
else:
193+
# spawn multiple processes
194+
self.connection.close() # disconnect parent process from MySQL server
195+
del self.connection._conn.ctx # SSLContext is not pickleable
196+
with mp.Pool(processes, _initialize_populate, (self, populate_kwargs)) as pool:
197+
if display_progress:
198+
with tqdm(desc="Processes: ", total=nkeys) as pbar:
199+
for error in pool.imap(_call_populate1, keys, chunksize=1):
200+
if error is not None:
201+
error_list.append(error)
202+
pbar.update()
149203
else:
150-
logger.info('Populating: ' + str(key))
151-
call_count += 1
152-
self.__class__._allow_insert = True
153-
try:
154-
make(dict(key), **(make_kwargs or {}))
155-
except (KeyboardInterrupt, SystemExit, Exception) as error:
156-
try:
157-
self.connection.cancel_transaction()
158-
except LostConnectionError:
159-
pass
160-
error_message = '{exception}{msg}'.format(
161-
exception=error.__class__.__name__,
162-
msg=': ' + str(error) if str(error) else '')
163-
if reserve_jobs:
164-
# show error name and error message (if any)
165-
jobs.error(
166-
self.target.table_name, self._job_key(key),
167-
error_message=error_message, error_stack=traceback.format_exc())
168-
if not suppress_errors or isinstance(error, SystemExit):
169-
raise
170-
else:
171-
logger.error(error)
172-
error_list.append((key, error if return_exception_objects else error_message))
173-
else:
174-
self.connection.commit_transaction()
175-
if reserve_jobs:
176-
jobs.complete(self.target.table_name, self._job_key(key))
177-
finally:
178-
self.__class__._allow_insert = False
204+
for error in pool.imap(_call_populate1, keys):
205+
if error is not None:
206+
error_list.append(error)
207+
self.connection.connect() # reconnect parent process to MySQL server
179208

180-
# place back the original signal handler
209+
# restore original signal handler:
181210
if reserve_jobs:
182211
signal.signal(signal.SIGTERM, old_handler)
183-
return error_list
212+
213+
if suppress_errors:
214+
return error_list
215+
216+
def _populate1(self, key, jobs, suppress_errors, return_exception_objects, make_kwargs=None):
217+
"""
218+
populates table for one source key, calling self.make inside a transaction.
219+
:param jobs: the jobs table or None if not reserve_jobs
220+
:param key: dict specifying job to populate
221+
:param suppress_errors: bool if errors should be suppressed and returned
222+
:param return_exception_objects: if True, errors must be returned as objects
223+
:return: (key, error) when suppress_errors=True, otherwise None
224+
"""
225+
make = self._make_tuples if hasattr(self, '_make_tuples') else self.make
226+
227+
if jobs is None or jobs.reserve(self.target.table_name, self._job_key(key)):
228+
self.connection.start_transaction()
229+
if key in self.target: # already populated
230+
self.connection.cancel_transaction()
231+
if jobs is not None:
232+
jobs.complete(self.target.table_name, self._job_key(key))
233+
else:
234+
logger.info('Populating: ' + str(key))
235+
self.__class__._allow_insert = True
236+
try:
237+
make(dict(key), **(make_kwargs or {}))
238+
except (KeyboardInterrupt, SystemExit, Exception) as error:
239+
try:
240+
self.connection.cancel_transaction()
241+
except LostConnectionError:
242+
pass
243+
error_message = '{exception}{msg}'.format(
244+
exception=error.__class__.__name__,
245+
msg=': ' + str(error) if str(error) else '')
246+
if jobs is not None:
247+
# show error name and error message (if any)
248+
jobs.error(
249+
self.target.table_name, self._job_key(key),
250+
error_message=error_message, error_stack=traceback.format_exc())
251+
if not suppress_errors or isinstance(error, SystemExit):
252+
raise
253+
else:
254+
logger.error(error)
255+
return key, error if return_exception_objects else error_message
256+
else:
257+
self.connection.commit_transaction()
258+
if jobs is not None:
259+
jobs.complete(self.target.table_name, self._job_key(key))
260+
finally:
261+
self.__class__._allow_insert = False
184262

185263
def progress(self, *restrictions, display=True):
186264
"""
187-
report progress of populating the table
188-
:return: remaining, total -- tuples to be populated
265+
Report the progress of populating the table.
266+
:return: (remaining, total) -- numbers of tuples to be populated
189267
"""
190268
todo = self._jobs_to_do(restrictions)
191269
total = len(todo)
@@ -194,5 +272,6 @@ def progress(self, *restrictions, display=True):
194272
print('%-20s' % self.__class__.__name__,
195273
'Completed %d of %d (%2.1f%%) %s' % (
196274
total - remaining, total, 100 - 100 * remaining / (total+1e-12),
197-
datetime.datetime.strftime(datetime.datetime.now(), '%Y-%m-%d %H:%M:%S')), flush=True)
275+
datetime.datetime.strftime(datetime.datetime.now(),
276+
'%Y-%m-%d %H:%M:%S')), flush=True)
198277
return remaining, total

datajoint/blob.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ def pack_blob(self, obj):
166166
return self.pack_array(np.array(obj))
167167
if isinstance(obj, (bool, np.bool_)):
168168
return self.pack_array(np.array(obj))
169+
if isinstance(obj, (float, int, complex)):
170+
return self.pack_array(np.array(obj))
169171
if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
170172
return self.pack_datetime(obj)
171173
if isinstance(obj, Decimal):

datajoint/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def query(self, query, args=(), *, as_dict=False, suppress_warnings=True, reconn
278278
# check cache first:
279279
use_query_cache = bool(self._query_cache)
280280
if use_query_cache and not re.match(r"\s*(SELECT|SHOW)", query):
281-
raise errors.DataJointError("Only SELECT query are allowed when query caching is on.")
281+
raise errors.DataJointError("Only SELECT queries are allowed when query caching is on.")
282282
if use_query_cache:
283283
if not config['query_cache']:
284284
raise errors.DataJointError("Provide filepath dj.config['query_cache'] when using query caching.")

datajoint/dependencies.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ def unite_master_parts(lst):
2626
# move from the ith position to the (j+1)th position
2727
lst[j+1:i+1] = [name] + lst[j+1:i]
2828
break
29-
else:
30-
raise DataJointError("Found a part table {name} without its master table.".format(name=name))
3129
return lst
3230

3331

0 commit comments

Comments
 (0)