Skip to content

Commit a91f239

Browse files
authored
Merge pull request #969 from dimitri-yatsenko/mp
Implement multiprocessing in `populate`
2 parents 9b8c4eb + 9ba4e86 commit a91f239

File tree

8 files changed

+133
-61
lines changed

8 files changed

+133
-61
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
* Bugfix - Fix Python 3.10 compatibility (#983) PR #972
66
* Bugfix - Allow renaming non-conforming attributes in proj (#982) PR #972
77
* Add - Expose proxy feature for S3 external stores (#961) PR #962
8+
* Add - implement multiprocessing in populate (#695) PR #704, #969
89
* Bugfix - Dependencies not properly loaded on populate. (#902) PR #919
910
* Bugfix - Replace use of numpy aliases of built-in types with built-in type. (#938) PR #939
1011
* Bugfix - `ExternalTable.delete` should not remove row on error (#953) PR #956

datajoint/autopopulate.py

Lines changed: 121 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,37 @@
88
from .expression import QueryExpression, AndList
99
from .errors import DataJointError, LostConnectionError
1010
import signal
11+
import multiprocessing as mp
1112

1213
# noinspection PyExceptionInherit,PyCallingNonCallable
1314

1415
logger = logging.getLogger(__name__)
1516

1617

18+
# --- helper functions for multiprocessing --
19+
20+
def _initialize_populate(table, jobs, populate_kwargs):
21+
"""
22+
Initialize the process for multiprocessing.
23+
Saves the unpickled copy of the table to the current process and reconnects.
24+
"""
25+
process = mp.current_process()
26+
process.table = table
27+
process.jobs = jobs
28+
process.populate_kwargs = populate_kwargs
29+
table.connection.connect() # reconnect
30+
31+
32+
def _call_populate1(key):
33+
"""
34+
Call current process' table._populate1()
35+
:param key: a dict specifying the job to compute
36+
:return: key, error if error, otherwise None
37+
"""
38+
process = mp.current_process()
39+
return process.table._populate1(key, process.jobs, **process.populate_kwargs)
40+
41+
1742
class AutoPopulate:
1843
"""
1944
AutoPopulate is a mixin class that adds the method populate() to a Relation class.
@@ -28,8 +53,9 @@ def key_source(self):
2853
"""
2954
:return: the query expression that yields primary key values to be passed,
3055
sequentially, to the ``make`` method when populate() is called.
31-
The default value is the join of the parent relations.
32-
Users may override to change the granularity or the scope of populate() calls.
56+
The default value is the join of the parent tables referenced from the primary key.
57+
Subclasses may override the key_source to change the scope or the granularity
58+
of the make calls.
3359
"""
3460
def _rename_attributes(table, props):
3561
return (table.proj(
@@ -96,29 +122,30 @@ def _jobs_to_do(self, restrictions):
96122

97123
def populate(self, *restrictions, suppress_errors=False, return_exception_objects=False,
98124
reserve_jobs=False, order="original", limit=None, max_calls=None,
99-
display_progress=False):
125+
display_progress=False, processes=1):
100126
"""
101-
rel.populate() calls rel.make(key) for every primary key in self.key_source
102-
for which there is not already a tuple in rel.
103-
:param restrictions: a list of restrictions each restrict (rel.key_source - target.proj())
127+
table.populate() calls table.make(key) for every primary key in self.key_source
128+
for which there is not already a tuple in table.
129+
:param restrictions: a list of restrictions each restrict (table.key_source - target.proj())
104130
:param suppress_errors: if True, do not terminate execution.
105131
:param return_exception_objects: return error objects instead of just error messages
106-
:param reserve_jobs: if true, reserves job to populate in asynchronous fashion
132+
:param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion
107133
:param order: "original"|"reverse"|"random" - the order of execution
134+
:param limit: if not None, check at most this many keys
135+
:param max_calls: if not None, populate at most this many keys
108136
:param display_progress: if True, report progress_bar
109-
:param limit: if not None, checks at most that many keys
110-
:param max_calls: if not None, populates at max that many keys
137+
:param processes: number of processes to use. When set to a large number, then
138+
uses as many processes as there are CPU cores
111139
"""
112140
if self.connection.in_transaction:
113141
raise DataJointError('Populate cannot be called during a transaction.')
114142

115143
valid_order = ['original', 'reverse', 'random']
116144
if order not in valid_order:
117145
raise DataJointError('The order argument must be one of %s' % str(valid_order))
118-
error_list = [] if suppress_errors else None
119146
jobs = self.connection.schemas[self.target.database].jobs if reserve_jobs else None
120147

121-
# define and setup signal handler for SIGTERM
148+
# define and set up signal handler for SIGTERM:
122149
if reserve_jobs:
123150
def handler(signum, frame):
124151
logger.info('Populate terminated by SIGTERM')
@@ -131,60 +158,99 @@ def handler(signum, frame):
131158
elif order == "random":
132159
random.shuffle(keys)
133160

134-
call_count = 0
135161
logger.info('Found %d keys to populate' % len(keys))
136162

137-
make = self._make_tuples if hasattr(self, '_make_tuples') else self.make
163+
keys = keys[:max_calls]
164+
nkeys = len(keys)
138165

139-
for key in (tqdm(keys, desc=self.__class__.__name__) if display_progress else keys):
140-
if max_calls is not None and call_count >= max_calls:
141-
break
142-
if not reserve_jobs or jobs.reserve(self.target.table_name, self._job_key(key)):
143-
self.connection.start_transaction()
144-
if key in self.target: # already populated
145-
self.connection.cancel_transaction()
146-
if reserve_jobs:
147-
jobs.complete(self.target.table_name, self._job_key(key))
166+
if processes > 1:
167+
processes = min(processes, nkeys, mp.cpu_count())
168+
169+
error_list = []
170+
populate_kwargs = dict(
171+
suppress_errors=suppress_errors,
172+
return_exception_objects=return_exception_objects)
173+
174+
if processes == 1:
175+
for key in tqdm(keys, desc=self.__class__.__name__) if display_progress else keys:
176+
error = self._populate1(key, jobs, **populate_kwargs)
177+
if error is not None:
178+
error_list.append(error)
179+
else:
180+
# spawn multiple processes
181+
self.connection.close() # disconnect parent process from MySQL server
182+
del self.connection._conn.ctx # SSLContext is not pickleable
183+
with mp.Pool(processes, _initialize_populate, (self, populate_kwargs)) as pool:
184+
if display_progress:
185+
with tqdm(desc="Processes: ", total=nkeys) as pbar:
186+
for error in pool.imap(_call_populate1, keys, chunksize=1):
187+
if error is not None:
188+
error_list.append(error)
189+
pbar.update()
148190
else:
149-
logger.info('Populating: ' + str(key))
150-
call_count += 1
151-
self.__class__._allow_insert = True
152-
try:
153-
make(dict(key))
154-
except (KeyboardInterrupt, SystemExit, Exception) as error:
155-
try:
156-
self.connection.cancel_transaction()
157-
except LostConnectionError:
158-
pass
159-
error_message = '{exception}{msg}'.format(
160-
exception=error.__class__.__name__,
161-
msg=': ' + str(error) if str(error) else '')
162-
if reserve_jobs:
163-
# show error name and error message (if any)
164-
jobs.error(
165-
self.target.table_name, self._job_key(key),
166-
error_message=error_message, error_stack=traceback.format_exc())
167-
if not suppress_errors or isinstance(error, SystemExit):
168-
raise
169-
else:
170-
logger.error(error)
171-
error_list.append((key, error if return_exception_objects else error_message))
172-
else:
173-
self.connection.commit_transaction()
174-
if reserve_jobs:
175-
jobs.complete(self.target.table_name, self._job_key(key))
176-
finally:
177-
self.__class__._allow_insert = False
191+
for error in pool.imap(_call_populate1, keys):
192+
if error is not None:
193+
error_list.append(error)
194+
self.connection.connect() # reconnect parent process to MySQL server
178195

179-
# place back the original signal handler
196+
# restore original signal handler:
180197
if reserve_jobs:
181198
signal.signal(signal.SIGTERM, old_handler)
182-
return error_list
199+
200+
if suppress_errors:
201+
return error_list
202+
203+
def _populate1(self, key, jobs, suppress_errors, return_exception_objects):
204+
"""
205+
populates table for one source key, calling self.make inside a transaction.
206+
:param jobs: the jobs table or None if not reserve_jobs
207+
:param key: dict specifying job to populate
208+
:param suppress_errors: bool if errors should be suppressed and returned
209+
:param return_exception_objects: if True, errors must be returned as objects
210+
:return: (key, error) when suppress_errors=True, otherwise None
211+
"""
212+
make = self._make_tuples if hasattr(self, '_make_tuples') else self.make
213+
214+
if jobs is None or jobs.reserve(self.target.table_name, self._job_key(key)):
215+
self.connection.start_transaction()
216+
if key in self.target: # already populated
217+
self.connection.cancel_transaction()
218+
if jobs is not None:
219+
jobs.complete(self.target.table_name, self._job_key(key))
220+
else:
221+
logger.info('Populating: ' + str(key))
222+
self.__class__._allow_insert = True
223+
try:
224+
make(dict(key))
225+
except (KeyboardInterrupt, SystemExit, Exception) as error:
226+
try:
227+
self.connection.cancel_transaction()
228+
except LostConnectionError:
229+
pass
230+
error_message = '{exception}{msg}'.format(
231+
exception=error.__class__.__name__,
232+
msg=': ' + str(error) if str(error) else '')
233+
if jobs is not None:
234+
# show error name and error message (if any)
235+
jobs.error(
236+
self.target.table_name, self._job_key(key),
237+
error_message=error_message, error_stack=traceback.format_exc())
238+
if not suppress_errors or isinstance(error, SystemExit):
239+
raise
240+
else:
241+
logger.error(error)
242+
return key, error if return_exception_objects else error_message
243+
else:
244+
self.connection.commit_transaction()
245+
if jobs is not None:
246+
jobs.complete(self.target.table_name, self._job_key(key))
247+
finally:
248+
self.__class__._allow_insert = False
183249

184250
def progress(self, *restrictions, display=True):
185251
"""
186-
report progress of populating the table
187-
:return: remaining, total -- tuples to be populated
252+
Report the progress of populating the table.
253+
:return: (remaining, total) -- numbers of tuples to be populated
188254
"""
189255
todo = self._jobs_to_do(restrictions)
190256
total = len(todo)

datajoint/blob.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ def pack_blob(self, obj):
166166
return self.pack_array(np.array(obj))
167167
if isinstance(obj, (bool, np.bool_)):
168168
return self.pack_array(np.array(obj))
169+
if isinstance(obj, (float, int, complex)):
170+
return self.pack_array(np.array(obj))
169171
if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
170172
return self.pack_datetime(obj)
171173
if isinstance(obj, Decimal):

datajoint/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def query(self, query, args=(), *, as_dict=False, suppress_warnings=True, reconn
278278
# check cache first:
279279
use_query_cache = bool(self._query_cache)
280280
if use_query_cache and not re.match(r"\s*(SELECT|SHOW)", query):
281-
raise errors.DataJointError("Only SELECT query are allowed when query caching is on.")
281+
raise errors.DataJointError("Only SELECT queries are allowed when query caching is on.")
282282
if use_query_cache:
283283
if not config['query_cache']:
284284
raise errors.DataJointError("Provide filepath dj.config['query_cache'] when using query caching.")

datajoint/expression.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class QueryExpression:
3737
"""
3838
_restriction = None
3939
_restriction_attributes = None
40-
_left = [] # True for left joins, False for inner joins
40+
_left = [] # list of booleans True for left joins, False for inner joins
4141
_original_heading = None # heading before projections
4242

4343
# subclasses or instantiators must provide values
@@ -263,7 +263,7 @@ def join(self, other, semantic_check=True, left=False):
263263
if semantic_check:
264264
assert_join_compatibility(self, other)
265265
join_attributes = set(n for n in self.heading.names if n in other.heading.names)
266-
# needs subquery if FROM class has common attributes with the other's FROM clause
266+
# needs subquery if self's FROM clause has common attributes with other's FROM clause
267267
need_subquery1 = need_subquery2 = bool(
268268
(set(self.original_heading.names) & set(other.original_heading.names))
269269
- join_attributes)
@@ -306,7 +306,7 @@ def proj(self, *attributes, **named_attributes):
306306
self.proj(...) or self.proj(Ellipsis) -- include all attributes (return self)
307307
self.proj() -- include only primary key
308308
self.proj('attr1', 'attr2') -- include primary key and attributes attr1 and attr2
309-
self.proj(..., '-attr1', '-attr2') -- include attributes except attr1 and attr2
309+
self.proj(..., '-attr1', '-attr2') -- include all attributes except attr1 and attr2
310310
self.proj(name1='attr1') -- include primary key and 'attr1' renamed as name1
311311
self.proj('attr1', dup='(attr1)') -- include primary key and attribute attr1 twice, with the duplicate 'dup'
312312
self.proj(k='abs(attr1)') adds the new attribute k with the value computed as an expression (SQL syntax)

docs-parts/intro/Releases_lang1.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
* Bugfix - Fix Python 3.10 compatibility (#983) PR #972
55
* Bugfix - Allow renaming non-conforming attributes in proj (#982) PR #972
66
* Add - Expose proxy feature for S3 external stores (#961) PR #962
7+
* Add - implement multiprocessing in populate (#695) PR #704, #969
78
* Bugfix - Dependencies not properly loaded on populate. (#902) PR #919
89
* Bugfix - Replace use of numpy aliases of built-in types with built-in type. (#938) PR #939
910
* Bugfix - `ExternalTable.delete` should not remove row on error (#953) PR #956

tests/test_autopopulate.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ class TestPopulate:
88
"""
99
Test base relations: insert, delete
1010
"""
11-
1211
def setUp(self):
1312
self.user = schema.User()
1413
self.subject = schema.Subject()
@@ -53,7 +52,7 @@ def test_populate(self):
5352

5453
def test_allow_direct_insert(self):
5554
assert_true(self.subject, 'root tables are empty')
56-
key = self.subject.fetch('KEY')[0]
55+
key = self.subject.fetch('KEY', limit=1)[0]
5756
key['experiment_id'] = 1000
5857
key['experiment_date'] = '2018-10-30'
5958
self.experiment.insert1(key, allow_direct_insert=True)

tests/test_blob.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ def test_pack():
2323
x = np.random.randn(10)
2424
assert_array_equal(x, unpack(pack(x)), "Arrays do not match!")
2525

26+
x = 7j
27+
assert_equal(x, unpack(pack(x)), "Complex scalar does not match")
28+
2629
x = np.float32(np.random.randn(3, 4, 5))
2730
assert_array_equal(x, unpack(pack(x)), "Arrays do not match!")
2831

0 commit comments

Comments
 (0)