Skip to content

Commit bd5b35f

Browse files
Merge branch 'master' into issue151
2 parents c14a07c + 1cd5d7c commit bd5b35f

File tree

8 files changed

+151
-71
lines changed

8 files changed

+151
-71
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
* Bugfix - Fix Python 3.10 compatibility (#983) PR #972
66
* Bugfix - Allow renaming non-conforming attributes in proj (#982) PR #972
77
* Add - Expose proxy feature for S3 external stores (#961) PR #962
8+
* Add - implement multiprocessing in populate (#695) PR #704, #969
89
* Bugfix - Dependencies not properly loaded on populate. (#902) PR #919
910
* Bugfix - Replace use of numpy aliases of built-in types with built-in type. (#938) PR #939
1011
* Bugfix - Deletes and drops must include the master of each part. (#151, #374) PR #957

datajoint/autopopulate.py

Lines changed: 139 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,37 @@
88
from .expression import QueryExpression, AndList
99
from .errors import DataJointError, LostConnectionError
1010
import signal
11+
import multiprocessing as mp
1112

1213
# noinspection PyExceptionInherit,PyCallingNonCallable
1314

1415
logger = logging.getLogger(__name__)
1516

1617

18+
# --- helper functions for multiprocessing --
19+
20+
def _initialize_populate(table, jobs, populate_kwargs):
    """
    Initialize a worker process for multiprocessing.

    Stores the unpickled copy of the table, the jobs table, and the populate
    keyword arguments on the current process object so that _call_populate1
    can retrieve them, then reconnects to the database (the parent process
    disconnects before forking, since connections cannot be shared across
    processes).

    :param table: the AutoPopulate table instance to be populated
    :param jobs: the jobs table used for job reservation, or None
    :param populate_kwargs: dict of keyword arguments forwarded to table._populate1
    """
    process = mp.current_process()
    process.table = table
    process.jobs = jobs
    process.populate_kwargs = populate_kwargs
    table.connection.connect()  # reconnect
30+
31+
32+
def _call_populate1(key):
    """
    Call the current process' table._populate1() using the state saved on the
    process by _initialize_populate.

    :param key: a dict specifying the job to compute
    :return: (key, error) if an error occurred, otherwise None
    """
    process = mp.current_process()
    return process.table._populate1(key, process.jobs, **process.populate_kwargs)
40+
41+
1742
class AutoPopulate:
1843
"""
1944
AutoPopulate is a mixin class that adds the method populate() to a Relation class.
@@ -28,16 +53,18 @@ def key_source(self):
2853
"""
2954
:return: the query expression that yields primary key values to be passed,
3055
sequentially, to the ``make`` method when populate() is called.
31-
The default value is the join of the parent relations.
32-
Users may override to change the granularity or the scope of populate() calls.
56+
The default value is the join of the parent tables referenced from the primary key.
57+
Subclasses may override the key_source to change the scope or the granularity
58+
of the make calls.
3359
"""
3460
def _rename_attributes(table, props):
3561
return (table.proj(
3662
**{attr: ref for attr, ref in props['attr_map'].items() if attr != ref})
3763
if props['aliased'] else table.proj())
3864

3965
if self._key_source is None:
40-
parents = self.target.parents(primary=True, as_objects=True, foreign_key_info=True)
66+
parents = self.target.parents(
67+
primary=True, as_objects=True, foreign_key_info=True)
4168
if not parents:
4269
raise DataJointError('A table must have dependencies '
4370
'from its primary key for auto-populate to work')
@@ -48,17 +75,19 @@ def _rename_attributes(table, props):
4875

4976
def make(self, key):
5077
"""
51-
Derived classes must implement method `make` that fetches data from tables that are
52-
above them in the dependency hierarchy, restricting by the given key, computes dependent
53-
attributes, and inserts the new tuples into self.
78+
Derived classes must implement method `make` that fetches data from tables
79+
above them in the dependency hierarchy, restricting by the given key,
80+
computes secondary attributes, and inserts the new tuples into self.
5481
"""
55-
raise NotImplementedError('Subclasses of AutoPopulate must implement the method `make`')
82+
raise NotImplementedError(
83+
'Subclasses of AutoPopulate must implement the method `make`')
5684

5785
@property
5886
def target(self):
5987
"""
6088
:return: table to be populated.
61-
In the typical case, dj.AutoPopulate is mixed into a dj.Table class by inheritance and the target is self.
89+
In the typical case, dj.AutoPopulate is mixed into a dj.Table class by
90+
inheritance and the target is self.
6291
"""
6392
return self
6493

@@ -85,40 +114,45 @@ def _jobs_to_do(self, restrictions):
85114

86115
if not isinstance(todo, QueryExpression):
87116
raise DataJointError('Invalid key_source value')
88-
# check if target lacks any attributes from the primary key of key_source
117+
89118
try:
119+
# check if target lacks any attributes from the primary key of key_source
90120
raise DataJointError(
91-
'The populate target lacks attribute %s from the primary key of key_source' % next(
92-
name for name in todo.heading.primary_key if name not in self.target.heading))
121+
'The populate target lacks attribute %s '
122+
'from the primary key of key_source' % next(
123+
name for name in todo.heading.primary_key
124+
if name not in self.target.heading))
93125
except StopIteration:
94126
pass
95127
return (todo & AndList(restrictions)).proj()
96128

97129
def populate(self, *restrictions, suppress_errors=False, return_exception_objects=False,
98130
reserve_jobs=False, order="original", limit=None, max_calls=None,
99-
display_progress=False):
131+
display_progress=False, processes=1):
100132
"""
101-
rel.populate() calls rel.make(key) for every primary key in self.key_source
102-
for which there is not already a tuple in rel.
103-
:param restrictions: a list of restrictions each restrict (rel.key_source - target.proj())
133+
table.populate() calls table.make(key) for every primary key in self.key_source
134+
for which there is not already a tuple in table.
135+
:param restrictions: a list of restrictions each restrict
136+
(table.key_source - target.proj())
104137
:param suppress_errors: if True, do not terminate execution.
105138
:param return_exception_objects: return error objects instead of just error messages
106-
:param reserve_jobs: if true, reserves job to populate in asynchronous fashion
139+
:param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion
107140
:param order: "original"|"reverse"|"random" - the order of execution
141+
:param limit: if not None, check at most this many keys
142+
:param max_calls: if not None, populate at most this many keys
108143
:param display_progress: if True, report progress_bar
109-
:param limit: if not None, checks at most that many keys
110-
:param max_calls: if not None, populates at max that many keys
144+
:param processes: number of processes to use. When set to a large number, then
145+
uses as many processes as there are CPU cores
111146
"""
112147
if self.connection.in_transaction:
113148
raise DataJointError('Populate cannot be called during a transaction.')
114149

115150
valid_order = ['original', 'reverse', 'random']
116151
if order not in valid_order:
117152
raise DataJointError('The order argument must be one of %s' % str(valid_order))
118-
error_list = [] if suppress_errors else None
119153
jobs = self.connection.schemas[self.target.database].jobs if reserve_jobs else None
120154

121-
# define and setup signal handler for SIGTERM
155+
# define and set up signal handler for SIGTERM:
122156
if reserve_jobs:
123157
def handler(signum, frame):
124158
logger.info('Populate terminated by SIGTERM')
@@ -131,60 +165,99 @@ def handler(signum, frame):
131165
elif order == "random":
132166
random.shuffle(keys)
133167

134-
call_count = 0
135168
logger.info('Found %d keys to populate' % len(keys))
136169

137-
make = self._make_tuples if hasattr(self, '_make_tuples') else self.make
170+
keys = keys[:max_calls]
171+
nkeys = len(keys)
138172

139-
for key in (tqdm(keys, desc=self.__class__.__name__) if display_progress else keys):
140-
if max_calls is not None and call_count >= max_calls:
141-
break
142-
if not reserve_jobs or jobs.reserve(self.target.table_name, self._job_key(key)):
143-
self.connection.start_transaction()
144-
if key in self.target: # already populated
145-
self.connection.cancel_transaction()
146-
if reserve_jobs:
147-
jobs.complete(self.target.table_name, self._job_key(key))
173+
if processes > 1:
174+
processes = min(processes, nkeys, mp.cpu_count())
175+
176+
error_list = []
177+
populate_kwargs = dict(
178+
suppress_errors=suppress_errors,
179+
return_exception_objects=return_exception_objects)
180+
181+
if processes == 1:
182+
for key in tqdm(keys, desc=self.__class__.__name__) if display_progress else keys:
183+
error = self._populate1(key, jobs, **populate_kwargs)
184+
if error is not None:
185+
error_list.append(error)
186+
else:
187+
# spawn multiple processes
188+
self.connection.close() # disconnect parent process from MySQL server
189+
del self.connection._conn.ctx # SSLContext is not pickleable
190+
with mp.Pool(processes, _initialize_populate, (self, jobs, populate_kwargs)) as pool:
191+
if display_progress:
192+
with tqdm(desc="Processes: ", total=nkeys) as pbar:
193+
for error in pool.imap(_call_populate1, keys, chunksize=1):
194+
if error is not None:
195+
error_list.append(error)
196+
pbar.update()
148197
else:
149-
logger.info('Populating: ' + str(key))
150-
call_count += 1
151-
self.__class__._allow_insert = True
152-
try:
153-
make(dict(key))
154-
except (KeyboardInterrupt, SystemExit, Exception) as error:
155-
try:
156-
self.connection.cancel_transaction()
157-
except LostConnectionError:
158-
pass
159-
error_message = '{exception}{msg}'.format(
160-
exception=error.__class__.__name__,
161-
msg=': ' + str(error) if str(error) else '')
162-
if reserve_jobs:
163-
# show error name and error message (if any)
164-
jobs.error(
165-
self.target.table_name, self._job_key(key),
166-
error_message=error_message, error_stack=traceback.format_exc())
167-
if not suppress_errors or isinstance(error, SystemExit):
168-
raise
169-
else:
170-
logger.error(error)
171-
error_list.append((key, error if return_exception_objects else error_message))
172-
else:
173-
self.connection.commit_transaction()
174-
if reserve_jobs:
175-
jobs.complete(self.target.table_name, self._job_key(key))
176-
finally:
177-
self.__class__._allow_insert = False
198+
for error in pool.imap(_call_populate1, keys):
199+
if error is not None:
200+
error_list.append(error)
201+
self.connection.connect() # reconnect parent process to MySQL server
178202

179-
# place back the original signal handler
203+
# restore original signal handler:
180204
if reserve_jobs:
181205
signal.signal(signal.SIGTERM, old_handler)
182-
return error_list
206+
207+
if suppress_errors:
208+
return error_list
209+
210+
def _populate1(self, key, jobs, suppress_errors, return_exception_objects):
    """
    Populate the table for one source key, calling self.make inside a transaction.

    :param key: dict specifying the job to populate
    :param jobs: the jobs table for job reservation, or None if not reserve_jobs
    :param suppress_errors: if True, errors are caught and returned instead of raised
    :param return_exception_objects: if True, return the exception object rather
        than its message string
    :return: (key, error) when an error occurred and suppress_errors=True,
        otherwise None
    """
    # _make_tuples is the legacy name of make; prefer it if the subclass defines it
    make = self._make_tuples if hasattr(self, '_make_tuples') else self.make

    # proceed only if no job reservation is used or this key's job was reserved
    if jobs is None or jobs.reserve(self.target.table_name, self._job_key(key)):
        self.connection.start_transaction()
        if key in self.target:  # already populated
            self.connection.cancel_transaction()
            if jobs is not None:
                jobs.complete(self.target.table_name, self._job_key(key))
        else:
            logger.info('Populating: ' + str(key))
            # temporarily allow direct inserts into the auto-populated table
            self.__class__._allow_insert = True
            try:
                make(dict(key))
            # KeyboardInterrupt/SystemExit are caught too so the transaction
            # is rolled back and the job is marked errored before re-raising
            except (KeyboardInterrupt, SystemExit, Exception) as error:
                try:
                    self.connection.cancel_transaction()
                except LostConnectionError:
                    pass  # rollback is moot if the connection is already gone
                error_message = '{exception}{msg}'.format(
                    exception=error.__class__.__name__,
                    msg=': ' + str(error) if str(error) else '')
                if jobs is not None:
                    # show error name and error message (if any)
                    jobs.error(
                        self.target.table_name, self._job_key(key),
                        error_message=error_message, error_stack=traceback.format_exc())
                if not suppress_errors or isinstance(error, SystemExit):
                    raise
                else:
                    logger.error(error)
                    return key, error if return_exception_objects else error_message
            else:
                # make succeeded: commit and mark the job complete
                self.connection.commit_transaction()
                if jobs is not None:
                    jobs.complete(self.target.table_name, self._job_key(key))
            finally:
                self.__class__._allow_insert = False
183256

184257
def progress(self, *restrictions, display=True):
185258
"""
186-
report progress of populating the table
187-
:return: remaining, total -- tuples to be populated
259+
Report the progress of populating the table.
260+
:return: (remaining, total) -- numbers of tuples to be populated
188261
"""
189262
todo = self._jobs_to_do(restrictions)
190263
total = len(todo)
@@ -193,5 +266,6 @@ def progress(self, *restrictions, display=True):
193266
print('%-20s' % self.__class__.__name__,
194267
'Completed %d of %d (%2.1f%%) %s' % (
195268
total - remaining, total, 100 - 100 * remaining / (total+1e-12),
196-
datetime.datetime.strftime(datetime.datetime.now(), '%Y-%m-%d %H:%M:%S')), flush=True)
269+
datetime.datetime.strftime(datetime.datetime.now(),
270+
'%Y-%m-%d %H:%M:%S')), flush=True)
197271
return remaining, total

datajoint/blob.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ def pack_blob(self, obj):
166166
return self.pack_array(np.array(obj))
167167
if isinstance(obj, (bool, np.bool_)):
168168
return self.pack_array(np.array(obj))
169+
if isinstance(obj, (float, int, complex)):
170+
return self.pack_array(np.array(obj))
169171
if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
170172
return self.pack_datetime(obj)
171173
if isinstance(obj, Decimal):

datajoint/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def query(self, query, args=(), *, as_dict=False, suppress_warnings=True, reconn
278278
# check cache first:
279279
use_query_cache = bool(self._query_cache)
280280
if use_query_cache and not re.match(r"\s*(SELECT|SHOW)", query):
281-
raise errors.DataJointError("Only SELECT query are allowed when query caching is on.")
281+
raise errors.DataJointError("Only SELECT queries are allowed when query caching is on.")
282282
if use_query_cache:
283283
if not config['query_cache']:
284284
raise errors.DataJointError("Provide filepath dj.config['query_cache'] when using query caching.")

datajoint/expression.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class QueryExpression:
3737
"""
3838
_restriction = None
3939
_restriction_attributes = None
40-
_left = [] # True for left joins, False for inner joins
40+
_left = [] # list of booleans True for left joins, False for inner joins
4141
_original_heading = None # heading before projections
4242

4343
# subclasses or instantiators must provide values
@@ -263,7 +263,7 @@ def join(self, other, semantic_check=True, left=False):
263263
if semantic_check:
264264
assert_join_compatibility(self, other)
265265
join_attributes = set(n for n in self.heading.names if n in other.heading.names)
266-
# needs subquery if FROM class has common attributes with the other's FROM clause
266+
# needs subquery if self's FROM clause has common attributes with other's FROM clause
267267
need_subquery1 = need_subquery2 = bool(
268268
(set(self.original_heading.names) & set(other.original_heading.names))
269269
- join_attributes)
@@ -306,7 +306,7 @@ def proj(self, *attributes, **named_attributes):
306306
self.proj(...) or self.proj(Ellipsis) -- include all attributes (return self)
307307
self.proj() -- include only primary key
308308
self.proj('attr1', 'attr2') -- include primary key and attributes attr1 and attr2
309-
self.proj(..., '-attr1', '-attr2') -- include attributes except attr1 and attr2
309+
self.proj(..., '-attr1', '-attr2') -- include all attributes except attr1 and attr2
310310
self.proj(name1='attr1') -- include primary key and 'attr1' renamed as name1
311311
self.proj('attr1', dup='(attr1)') -- include primary key and attribute attr1 twice, with the duplicate 'dup'
312312
self.proj(k='abs(attr1)') adds the new attribute k with the value computed as an expression (SQL syntax)

docs-parts/intro/Releases_lang1.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
* Bugfix - Fix Python 3.10 compatibility (#983) PR #972
55
* Bugfix - Allow renaming non-conforming attributes in proj (#982) PR #972
66
* Add - Expose proxy feature for S3 external stores (#961) PR #962
7+
* Add - implement multiprocessing in populate (#695) PR #704, #969
78
* Bugfix - Dependencies not properly loaded on populate. (#902) PR #919
89
* Bugfix - Replace use of numpy aliases of built-in types with built-in type. (#938) PR #939
910
* Bugfix - Deletes and drops must include the master of each part. (#151, #374) PR #957

tests/test_autopopulate.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ class TestPopulate:
88
"""
99
Test base relations: insert, delete
1010
"""
11-
1211
def setUp(self):
1312
self.user = schema.User()
1413
self.subject = schema.Subject()
@@ -53,7 +52,7 @@ def test_populate(self):
5352

5453
def test_allow_direct_insert(self):
5554
assert_true(self.subject, 'root tables are empty')
56-
key = self.subject.fetch('KEY')[0]
55+
key = self.subject.fetch('KEY', limit=1)[0]
5756
key['experiment_id'] = 1000
5857
key['experiment_date'] = '2018-10-30'
5958
self.experiment.insert1(key, allow_direct_insert=True)

tests/test_blob.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ def test_pack():
2323
x = np.random.randn(10)
2424
assert_array_equal(x, unpack(pack(x)), "Arrays do not match!")
2525

26+
x = 7j
27+
assert_equal(x, unpack(pack(x)), "Complex scalar does not match")
28+
2629
x = np.float32(np.random.randn(3, 4, 5))
2730
assert_array_equal(x, unpack(pack(x)), "Arrays do not match!")
2831

0 commit comments

Comments
 (0)