Skip to content

Commit fa59df2

Browse files
Merge commit fa59df2 (2 parents: d04f6ce + 3e6b174)

File tree

3 files changed

+120
-48
lines changed

3 files changed

+120
-48
lines changed

datajoint/autopopulate.py

Lines changed: 115 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -9,12 +9,26 @@
99
from .errors import DataJointError, LostConnectionError
1010
from .table import FreeTable
1111
import signal
12+
import multiprocessing as mp
1213

1314
# noinspection PyExceptionInherit,PyCallingNonCallable
1415

1516
logger = logging.getLogger(__name__)
1617

1718

19+
def initializer(table):
20+
"""Save pickled copy of (disconnected) table to the current process,
21+
then reconnect to server. For use by call_make_key()"""
22+
mp.current_process().table = table
23+
table.connection.connect() # reconnect
24+
25+
def call_make_key(key):
26+
"""Call current process' table.make_key()"""
27+
table = mp.current_process().table
28+
error = table.make_key(key)
29+
return error
30+
31+
1832
class AutoPopulate:
1933
"""
2034
AutoPopulate is a mixin class that adds the method populate() to a Relation class.
@@ -102,29 +116,36 @@ def _jobs_to_do(self, restrictions):
102116

103117
def populate(self, *restrictions, suppress_errors=False, return_exception_objects=False,
104118
reserve_jobs=False, order="original", limit=None, max_calls=None,
105-
display_progress=False):
119+
display_progress=False, multiprocess=False):
106120
"""
107121
rel.populate() calls rel.make(key) for every primary key in self.key_source
108122
for which there is not already a tuple in rel.
109123
:param restrictions: a list of restrictions each restrict (rel.key_source - target.proj())
110124
:param suppress_errors: if True, do not terminate execution.
111125
:param return_exception_objects: return error objects instead of just error messages
112-
:param reserve_jobs: if true, reserves job to populate in asynchronous fashion
126+
:param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion
113127
:param order: "original"|"reverse"|"random" - the order of execution
128+
:param limit: if not None, check at most this many keys
129+
:param max_calls: if not None, populate at most this many keys
114130
:param display_progress: if True, report progress_bar
115-
:param limit: if not None, checks at most that many keys
116-
:param max_calls: if not None, populates at max that many keys
131+
:param multiprocess: if True, use as many processes as CPU cores, or use the integer
132+
number of processes specified
117133
"""
118134
if self.connection.in_transaction:
119135
raise DataJointError('Populate cannot be called during a transaction.')
120136

121137
valid_order = ['original', 'reverse', 'random']
122138
if order not in valid_order:
123139
raise DataJointError('The order argument must be one of %s' % str(valid_order))
124-
error_list = [] if suppress_errors else None
125140
jobs = self.connection.schemas[self.target.database].jobs if reserve_jobs else None
126141

127-
# define and setup signal handler for SIGTERM
142+
self._make_key_kwargs = {'suppress_errors':suppress_errors,
143+
'return_exception_objects':return_exception_objects,
144+
'reserve_jobs':reserve_jobs,
145+
'jobs':jobs,
146+
}
147+
148+
# define and set up signal handler for SIGTERM:
128149
if reserve_jobs:
129150
def handler(signum, frame):
130151
logger.info('Populate terminated by SIGTERM')
@@ -137,55 +158,101 @@ def handler(signum, frame):
137158
elif order == "random":
138159
random.shuffle(keys)
139160

140-
call_count = 0
141161
logger.info('Found %d keys to populate' % len(keys))
142162

143-
make = self._make_tuples if hasattr(self, '_make_tuples') else self.make
163+
if max_calls is not None:
164+
keys = keys[:max_calls]
165+
nkeys = len(keys)
144166

145-
for key in (tqdm(keys) if display_progress else keys):
146-
if max_calls is not None and call_count >= max_calls:
147-
break
148-
if not reserve_jobs or jobs.reserve(self.target.table_name, self._job_key(key)):
149-
self.connection.start_transaction()
150-
if key in self.target: # already populated
151-
self.connection.cancel_transaction()
152-
if reserve_jobs:
153-
jobs.complete(self.target.table_name, self._job_key(key))
167+
if multiprocess: # True or int, presumably
168+
if multiprocess is True:
169+
nproc = mp.cpu_count()
170+
else:
171+
if not isinstance(multiprocess, int):
172+
raise DataJointError("multiprocess can be False, True or a positive integer")
173+
nproc = multiprocess
174+
else:
175+
nproc = 1
176+
nproc = min(nproc, nkeys) # no sense spawning more than can be used
177+
error_list = []
178+
if nproc > 1: # spawn multiple processes
179+
# prepare to pickle self:
180+
self.connection.close() # disconnect parent process from MySQL server
181+
del self.connection._conn.ctx # SSLContext is not picklable
182+
print('*** Spawning pool of %d processes' % nproc)
183+
# send pickled copy of self to each process,
184+
# each worker process calls initializer(*initargs) when it starts
185+
with mp.Pool(nproc, initializer, (self,)) as pool:
186+
if display_progress:
187+
with tqdm(total=nkeys) as pbar:
188+
for error in pool.imap(call_make_key, keys, chunksize=1):
189+
if error is not None:
190+
error_list.append(error)
191+
pbar.update()
154192
else:
155-
logger.info('Populating: ' + str(key))
156-
call_count += 1
157-
self.__class__._allow_insert = True
158-
try:
159-
make(dict(key))
160-
except (KeyboardInterrupt, SystemExit, Exception) as error:
161-
try:
162-
self.connection.cancel_transaction()
163-
except LostConnectionError:
164-
pass
165-
error_message = '{exception}{msg}'.format(
166-
exception=error.__class__.__name__,
167-
msg=': ' + str(error) if str(error) else '')
168-
if reserve_jobs:
169-
# show error name and error message (if any)
170-
jobs.error(
171-
self.target.table_name, self._job_key(key),
172-
error_message=error_message, error_stack=traceback.format_exc())
173-
if not suppress_errors or isinstance(error, SystemExit):
174-
raise
175-
else:
176-
logger.error(error)
177-
error_list.append((key, error if return_exception_objects else error_message))
178-
else:
179-
self.connection.commit_transaction()
180-
if reserve_jobs:
181-
jobs.complete(self.target.table_name, self._job_key(key))
182-
finally:
183-
self.__class__._allow_insert = False
193+
for error in pool.imap(call_make_key, keys):
194+
if error is not None:
195+
error_list.append(error)
196+
self.connection.connect() # reconnect parent process to MySQL server
197+
else: # use single process
198+
for key in tqdm(keys) if display_progress else keys:
199+
error = self.make_key(key)
200+
if error is not None:
201+
error_list.append(error)
184202

185-
# place back the original signal handler
203+
del self._make_key_kwargs # clean up
204+
205+
# restore original signal handler:
186206
if reserve_jobs:
187207
signal.signal(signal.SIGTERM, old_handler)
188-
return error_list
208+
209+
if suppress_errors:
210+
return error_list
211+
212+
def make_key(self, key):
213+
make = self._make_tuples if hasattr(self, '_make_tuples') else self.make
214+
215+
kwargs = self._make_key_kwargs
216+
suppress_errors = kwargs['suppress_errors']
217+
return_exception_objects = kwargs['return_exception_objects']
218+
reserve_jobs = kwargs['reserve_jobs']
219+
jobs = kwargs['jobs']
220+
221+
if not reserve_jobs or jobs.reserve(self.target.table_name, self._job_key(key)):
222+
self.connection.start_transaction()
223+
if key in self.target: # already populated
224+
self.connection.cancel_transaction()
225+
if reserve_jobs:
226+
jobs.complete(self.target.table_name, self._job_key(key))
227+
else:
228+
logger.info('Populating: ' + str(key))
229+
self.__class__._allow_insert = True
230+
try:
231+
make(dict(key))
232+
except (KeyboardInterrupt, SystemExit, Exception) as error:
233+
try:
234+
self.connection.cancel_transaction()
235+
except LostConnectionError:
236+
pass
237+
error_message = '{exception}{msg}'.format(
238+
exception=error.__class__.__name__,
239+
msg=': ' + str(error) if str(error) else '')
240+
if reserve_jobs:
241+
# show error name and error message (if any)
242+
jobs.error(
243+
self.target.table_name, self._job_key(key),
244+
error_message=error_message, error_stack=traceback.format_exc())
245+
if not suppress_errors or isinstance(error, SystemExit):
246+
raise
247+
else:
248+
logger.error(error)
249+
return (key, error if return_exception_objects else error_message)
250+
else:
251+
self.connection.commit_transaction()
252+
if reserve_jobs:
253+
jobs.complete(self.target.table_name, self._job_key(key))
254+
finally:
255+
self.__class__._allow_insert = False
189256

190257
def progress(self, *restrictions, display=True):
191258
"""

datajoint/blob.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -166,6 +166,8 @@ def pack_blob(self, obj):
166166
return self.pack_array(np.array(obj))
167167
if isinstance(obj, (np.bool, np.bool_)):
168168
return self.pack_array(np.array(obj))
169+
if isinstance(obj, (float, int, complex)):
170+
return self.pack_array(np.array(obj))
169171
if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
170172
return self.pack_datetime(obj)
171173
if isinstance(obj, Decimal):

tests/test_blob.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,9 @@ def test_pack():
2323
x = np.random.randn(10)
2424
assert_array_equal(x, unpack(pack(x)), "Arrays do not match!")
2525

26+
x = 7j
27+
assert_equal(x, unpack(pack(x)), "Complex scalar does not match")
28+
2629
x = np.float32(np.random.randn(3, 4, 5))
2730
assert_array_equal(x, unpack(pack(x)), "Arrays do not match!")
2831

0 commit comments

Comments (0)