Skip to content

Commit 3e6b174

Browse files
Merge pull request #704 from mspacek/mp
Add multiprocessing to AutoPopulate
2 parents e2f2b18 + b3ee8d9 commit 3e6b174

File tree

12 files changed: 169 additions and 70 deletions

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,10 @@
11
## Release notes
22

3+
## 0.12.3 -- Nov 22, 2019
4+
* Bugfix #675 (PR #705) networkx 2.4+ is now supported
5+
* Bugfix #698 and #699 (PR #706) display table definition in doc string and help
6+
* Bugfix #701 (PR #702) job reservation works with native python datatype support disabled
7+
38
### 0.12.2 -- Nov 11, 2019
49
* Bugfix - Convoluted error thrown if there is a reference to a non-existent table attribute (#691)
510
* Bugfix - Insert into external does not trim leading slash if defined in `dj.config['stores']['<store>']['location']` (#692)

datajoint/autopopulate.py

Lines changed: 115 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -10,12 +10,26 @@
1010
from .errors import DataJointError
1111
from .table import FreeTable
1212
import signal
13+
import multiprocessing as mp
1314

1415
# noinspection PyExceptionInherit,PyCallingNonCallable
1516

1617
logger = logging.getLogger(__name__)
1718

1819

20+
def initializer(table):
21+
"""Save pickled copy of (disconnected) table to the current process,
22+
then reconnect to server. For use by call_make_key()"""
23+
mp.current_process().table = table
24+
table.connection.connect() # reconnect
25+
26+
def call_make_key(key):
27+
"""Call current process' table.make_key()"""
28+
table = mp.current_process().table
29+
error = table.make_key(key)
30+
return error
31+
32+
1933
class AutoPopulate:
2034
"""
2135
AutoPopulate is a mixin class that adds the method populate() to a Relation class.
@@ -103,29 +117,36 @@ def _jobs_to_do(self, restrictions):
103117

104118
def populate(self, *restrictions, suppress_errors=False, return_exception_objects=False,
105119
reserve_jobs=False, order="original", limit=None, max_calls=None,
106-
display_progress=False):
120+
display_progress=False, multiprocess=False):
107121
"""
108122
rel.populate() calls rel.make(key) for every primary key in self.key_source
109123
for which there is not already a tuple in rel.
110124
:param restrictions: a list of restrictions each restrict (rel.key_source - target.proj())
111125
:param suppress_errors: if True, do not terminate execution.
112126
:param return_exception_objects: return error objects instead of just error messages
113-
:param reserve_jobs: if true, reserves job to populate in asynchronous fashion
127+
:param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion
114128
:param order: "original"|"reverse"|"random" - the order of execution
129+
:param limit: if not None, check at most this many keys
130+
:param max_calls: if not None, populate at most this many keys
115131
:param display_progress: if True, report progress_bar
116-
:param limit: if not None, checks at most that many keys
117-
:param max_calls: if not None, populates at max that many keys
132+
:param multiprocess: if True, use as many processes as CPU cores, or use the integer
133+
number of processes specified
118134
"""
119135
if self.connection.in_transaction:
120136
raise DataJointError('Populate cannot be called during a transaction.')
121137

122138
valid_order = ['original', 'reverse', 'random']
123139
if order not in valid_order:
124140
raise DataJointError('The order argument must be one of %s' % str(valid_order))
125-
error_list = [] if suppress_errors else None
126141
jobs = self.connection.schemas[self.target.database].jobs if reserve_jobs else None
127142

128-
# define and setup signal handler for SIGTERM
143+
self._make_key_kwargs = {'suppress_errors':suppress_errors,
144+
'return_exception_objects':return_exception_objects,
145+
'reserve_jobs':reserve_jobs,
146+
'jobs':jobs,
147+
}
148+
149+
# define and set up signal handler for SIGTERM:
129150
if reserve_jobs:
130151
def handler(signum, frame):
131152
logger.info('Populate terminated by SIGTERM')
@@ -138,55 +159,101 @@ def handler(signum, frame):
138159
elif order == "random":
139160
random.shuffle(keys)
140161

141-
call_count = 0
142162
logger.info('Found %d keys to populate' % len(keys))
143163

144-
make = self._make_tuples if hasattr(self, '_make_tuples') else self.make
164+
if max_calls is not None:
165+
keys = keys[:max_calls]
166+
nkeys = len(keys)
145167

146-
for key in (tqdm(keys) if display_progress else keys):
147-
if max_calls is not None and call_count >= max_calls:
148-
break
149-
if not reserve_jobs or jobs.reserve(self.target.table_name, self._job_key(key)):
150-
self.connection.start_transaction()
151-
if key in self.target: # already populated
152-
self.connection.cancel_transaction()
153-
if reserve_jobs:
154-
jobs.complete(self.target.table_name, self._job_key(key))
168+
if multiprocess: # True or int, presumably
169+
if multiprocess is True:
170+
nproc = mp.cpu_count()
171+
else:
172+
if not isinstance(multiprocess, int):
173+
raise DataJointError("multiprocess can be False, True or a positive integer")
174+
nproc = multiprocess
175+
else:
176+
nproc = 1
177+
nproc = min(nproc, nkeys) # no sense spawning more than can be used
178+
error_list = []
179+
if nproc > 1: # spawn multiple processes
180+
# prepare to pickle self:
181+
self.connection.close() # disconnect parent process from MySQL server
182+
del self.connection._conn.ctx # SSLContext is not picklable
183+
print('*** Spawning pool of %d processes' % nproc)
184+
# send pickled copy of self to each process,
185+
# each worker process calls initializer(*initargs) when it starts
186+
with mp.Pool(nproc, initializer, (self,)) as pool:
187+
if display_progress:
188+
with tqdm(total=nkeys) as pbar:
189+
for error in pool.imap(call_make_key, keys, chunksize=1):
190+
if error is not None:
191+
error_list.append(error)
192+
pbar.update()
155193
else:
156-
logger.info('Populating: ' + str(key))
157-
call_count += 1
158-
self.__class__._allow_insert = True
159-
try:
160-
make(dict(key))
161-
except (KeyboardInterrupt, SystemExit, Exception) as error:
162-
try:
163-
self.connection.cancel_transaction()
164-
except OperationalError:
165-
pass
166-
error_message = '{exception}{msg}'.format(
167-
exception=error.__class__.__name__,
168-
msg=': ' + str(error) if str(error) else '')
169-
if reserve_jobs:
170-
# show error name and error message (if any)
171-
jobs.error(
172-
self.target.table_name, self._job_key(key),
173-
error_message=error_message, error_stack=traceback.format_exc())
174-
if not suppress_errors or isinstance(error, SystemExit):
175-
raise
176-
else:
177-
logger.error(error)
178-
error_list.append((key, error if return_exception_objects else error_message))
179-
else:
180-
self.connection.commit_transaction()
181-
if reserve_jobs:
182-
jobs.complete(self.target.table_name, self._job_key(key))
183-
finally:
184-
self.__class__._allow_insert = False
194+
for error in pool.imap(call_make_key, keys):
195+
if error is not None:
196+
error_list.append(error)
197+
self.connection.connect() # reconnect parent process to MySQL server
198+
else: # use single process
199+
for key in tqdm(keys) if display_progress else keys:
200+
error = self.make_key(key)
201+
if error is not None:
202+
error_list.append(error)
185203

186-
# place back the original signal handler
204+
del self._make_key_kwargs # clean up
205+
206+
# restore original signal handler:
187207
if reserve_jobs:
188208
signal.signal(signal.SIGTERM, old_handler)
189-
return error_list
209+
210+
if suppress_errors:
211+
return error_list
212+
213+
def make_key(self, key):
214+
make = self._make_tuples if hasattr(self, '_make_tuples') else self.make
215+
216+
kwargs = self._make_key_kwargs
217+
suppress_errors = kwargs['suppress_errors']
218+
return_exception_objects = kwargs['return_exception_objects']
219+
reserve_jobs = kwargs['reserve_jobs']
220+
jobs = kwargs['jobs']
221+
222+
if not reserve_jobs or jobs.reserve(self.target.table_name, self._job_key(key)):
223+
self.connection.start_transaction()
224+
if key in self.target: # already populated
225+
self.connection.cancel_transaction()
226+
if reserve_jobs:
227+
jobs.complete(self.target.table_name, self._job_key(key))
228+
else:
229+
logger.info('Populating: ' + str(key))
230+
self.__class__._allow_insert = True
231+
try:
232+
make(dict(key))
233+
except (KeyboardInterrupt, SystemExit, Exception) as error:
234+
try:
235+
self.connection.cancel_transaction()
236+
except OperationalError:
237+
pass
238+
error_message = '{exception}{msg}'.format(
239+
exception=error.__class__.__name__,
240+
msg=': ' + str(error) if str(error) else '')
241+
if reserve_jobs:
242+
# show error name and error message (if any)
243+
jobs.error(
244+
self.target.table_name, self._job_key(key),
245+
error_message=error_message, error_stack=traceback.format_exc())
246+
if not suppress_errors or isinstance(error, SystemExit):
247+
raise
248+
else:
249+
logger.error(error)
250+
return (key, error if return_exception_objects else error_message)
251+
else:
252+
self.connection.commit_transaction()
253+
if reserve_jobs:
254+
jobs.complete(self.target.table_name, self._job_key(key))
255+
finally:
256+
self.__class__._allow_insert = False
190257

191258
def progress(self, *restrictions, display=True):
192259
"""

datajoint/blob.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -152,10 +152,8 @@ def pack_blob(self, obj):
152152
return self.pack_array(np.array(obj))
153153
if isinstance(obj, (bool, np.bool, np.bool_)):
154154
return self.pack_array(np.array(obj))
155-
if isinstance(obj, float):
156-
return self.pack_array(np.array(obj, dtype=np.float64))
157-
if isinstance(obj, int):
158-
return self.pack_array(np.array(obj, dtype=np.int64))
155+
if isinstance(obj, (float, int, complex)):
156+
return self.pack_array(np.array(obj))
159157
if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
160158
return self.pack_datetime(obj)
161159
if isinstance(obj, Decimal):

datajoint/heading.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -247,7 +247,10 @@ def init_from_database(self, conn, database, table_name, context):
247247
category = next(c for c in SPECIAL_TYPES if TYPE_PATTERN[c].match(attr['type']))
248248
except StopIteration:
249249
if attr['type'].startswith('external'):
250-
raise DataJointError('Legacy datatype `{type}`.'.format(**attr)) from None
250+
url = "https://docs.datajoint.io/python/admin/5-blob-config.html" \
251+
"#migration-between-datajoint-v0-11-and-v0-12"
252+
raise DataJointError('Legacy datatype `{type}`. Migrate your external stores to '
253+
'datajoint 0.12: {url}'.format(url=url, **attr)) from None
251254
raise DataJointError('Unknown attribute type `{type}`'.format(**attr)) from None
252255
if category == 'FILEPATH' and not _support_filepath_types():
253256
raise DataJointError("""

datajoint/schema.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -191,6 +191,10 @@ def process_table_class(self, table_class, context, assert_declared=False):
191191
instance.declare(context)
192192
is_declared = is_declared or instance.is_declared
193193

194+
# add table definition to the doc string
195+
if isinstance(table_class.definition, str):
196+
table_class.__doc__ = (table_class.__doc__ or "") + "\nTable definition:\n\n" + table_class.definition
197+
194198
# fill values in Lookup tables from their contents property
195199
if isinstance(instance, Lookup) and hasattr(instance, 'contents') and is_declared:
196200
contents = list(instance.contents)

datajoint/table.py

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -45,14 +45,9 @@ def heading(self):
4545
"""
4646
if self._heading is None:
4747
self._heading = Heading() # instance-level heading
48-
if not self._heading: # lazy loading of heading
49-
if self.connection is None:
50-
raise DataJointError(
51-
'DataJoint class is missing a database connection. '
52-
'Missing schema decorator on the class? (e.g. @schema)')
53-
else:
54-
self._heading.init_from_database(
55-
self.connection, self.database, self.table_name, self.declaration_context)
48+
if not self._heading and self.connection is not None: # lazy loading of heading
49+
self._heading.init_from_database(
50+
self.connection, self.database, self.table_name, self.declaration_context)
5651
return self._heading
5752

5853
def declare(self, context=None):
@@ -411,7 +406,7 @@ def delete(self, verbose=True):
411406
print('About to delete:')
412407

413408
if not already_in_transaction:
414-
self.connection.start_transaction()
409+
conn.start_transaction()
415410
total = 0
416411
try:
417412
for name, table in reversed(list(delete_list.items())):
@@ -423,25 +418,25 @@ def delete(self, verbose=True):
423418
except:
424419
# Delete failed, perhaps due to insufficient privileges. Cancel transaction.
425420
if not already_in_transaction:
426-
self.connection.cancel_transaction()
421+
conn.cancel_transaction()
427422
raise
428423
else:
429424
assert not (already_in_transaction and safe)
430425
if not total:
431426
print('Nothing to delete')
432427
if not already_in_transaction:
433-
self.connection.cancel_transaction()
428+
conn.cancel_transaction()
434429
else:
435430
if already_in_transaction:
436431
if verbose:
437432
print('The delete is pending within the ongoing transaction.')
438433
else:
439434
if not safe or user_choice("Proceed?", default='no') == 'yes':
440-
self.connection.commit_transaction()
435+
conn.commit_transaction()
441436
if verbose or safe:
442437
print('Committed.')
443438
else:
444-
self.connection.cancel_transaction()
439+
conn.cancel_transaction()
445440
if verbose or safe:
446441
print('Cancelled deletes.')
447442

datajoint/user_tables.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
# attributes that trigger instantiation of user classes
1414
supported_class_attrs = {
1515
'key_source', 'describe', 'alter', 'heading', 'populate', 'progress', 'primary_key', 'proj', 'aggr',
16-
'fetch', 'fetch1','head', 'tail',
16+
'fetch', 'fetch1', 'head', 'tail',
1717
'insert', 'insert1', 'drop', 'drop_quick', 'delete', 'delete_quick'}
1818

1919

@@ -92,7 +92,7 @@ def table_name(cls):
9292

9393
@ClassProperty
9494
def full_table_name(cls):
95-
if cls not in {Manual, Imported, Lookup, Computed, Part}:
95+
if cls not in {Manual, Imported, Lookup, Computed, Part, UserTable}:
9696
if cls.database is None:
9797
raise DataJointError('Class %s is not properly declared (schema decorator not applied?)' % cls.__name__)
9898
return r"`{0:s}`.`{1:s}`".format(cls.database, cls.table_name)

datajoint/version.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,3 @@
1-
__version__ = "0.12.2"
1+
__version__ = "0.12.3"
22

33
assert len(__version__) <= 10 # The log table limits version to the 10 characters

docs-parts/intro/Releases_lang1.rst

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,11 @@
1-
0.12.1 -- Nov 11, 2019
1+
0.12.3 -- Nov 22, 2019
2+
----------------------
3+
* Bugfix - networkx 2.4 causes error in diagrams (#675) PR #705
4+
* Bugfix - include definition in doc string and help (#698, #699) PR #706
5+
* Bugfix - job reservation fails when native python datatype support is disabled (#701) PR #702
6+
7+
8+
0.12.2 -- Nov 11, 2019
29
-------------------------
310
* Bugfix - Convoluted error thrown if there is a reference to a non-existent table attribute (#691)
411
* Bugfix - Insert into external does not trim leading slash if defined in `dj.config['stores']['<store>']['location']` (#692)

tests/schema.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,9 @@
1313

1414
@schema
1515
class TTest(dj.Lookup):
16+
"""
17+
doc string
18+
"""
1619
definition = """
1720
key : int # key
1821
---

0 commit comments

Comments
 (0)