Skip to content

Commit 10511e7

Browse files
Merge pull request #1050 from ttngu207/populate_success_count
Returning success count from the `.populate()` call
2 parents 2a11279 + 18fd619 commit 10511e7

File tree

3 files changed

+118
-79
lines changed

3 files changed

+118
-79
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
- Changed - Migrate docs from `https://docs.datajoint.org/python` to `https://datajoint.com/docs/core/datajoint-python`
88
- Fixed - Updated set_password to work on MySQL 8 - PR [#1106](https://github.com/datajoint/datajoint-python/pull/1106)
99
- Added - Missing tests for set_password - PR [#1106](https://github.com/datajoint/datajoint-python/pull/1106)
10+
- Changed - `.populate()` now returns a dict with the success count and the error list - PR [#1050](https://github.com/datajoint/datajoint-python/pull/1050)
1011

1112
### 0.14.1 -- Jun 02, 2023
1213
- Fixed - Fix altering a part table that uses the "master" keyword - PR [#991](https://github.com/datajoint/datajoint-python/pull/991)

datajoint/autopopulate.py

Lines changed: 100 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,9 @@ def populate(
180180
to be passed down to each ``make()`` call. Computation arguments should be
181181
specified within the pipeline e.g. using a `dj.Lookup` table.
182182
:type make_kwargs: dict, optional
183+
:return: a dict with two keys
184+
"success_count": the count of successful ``make()`` calls in this ``populate()`` call
185+
"error_list": the list of (key, error) tuples collected when `suppress_errors` is True; empty otherwise
183186
"""
184187
if self.connection.in_transaction:
185188
raise DataJointError("Populate cannot be called during a transaction.")
@@ -222,49 +225,62 @@ def handler(signum, frame):
222225

223226
keys = keys[:max_calls]
224227
nkeys = len(keys)
225-
if not nkeys:
226-
return
227-
228-
processes = min(_ for _ in (processes, nkeys, mp.cpu_count()) if _)
229228

230229
error_list = []
231-
populate_kwargs = dict(
232-
suppress_errors=suppress_errors,
233-
return_exception_objects=return_exception_objects,
234-
make_kwargs=make_kwargs,
235-
)
230+
success_list = []
236231

237-
if processes == 1:
238-
for key in (
239-
tqdm(keys, desc=self.__class__.__name__) if display_progress else keys
240-
):
241-
error = self._populate1(key, jobs, **populate_kwargs)
242-
if error is not None:
243-
error_list.append(error)
244-
else:
245-
# spawn multiple processes
246-
self.connection.close() # disconnect parent process from MySQL server
247-
del self.connection._conn.ctx # SSLContext is not pickleable
248-
with mp.Pool(
249-
processes, _initialize_populate, (self, jobs, populate_kwargs)
250-
) as pool, (
251-
tqdm(desc="Processes: ", total=nkeys)
252-
if display_progress
253-
else contextlib.nullcontext()
254-
) as progress_bar:
255-
for error in pool.imap(_call_populate1, keys, chunksize=1):
256-
if error is not None:
257-
error_list.append(error)
258-
if display_progress:
259-
progress_bar.update()
260-
self.connection.connect() # reconnect parent process to MySQL server
232+
if nkeys:
233+
processes = min(_ for _ in (processes, nkeys, mp.cpu_count()) if _)
234+
235+
populate_kwargs = dict(
236+
suppress_errors=suppress_errors,
237+
return_exception_objects=return_exception_objects,
238+
make_kwargs=make_kwargs,
239+
)
240+
241+
if processes == 1:
242+
for key in (
243+
tqdm(keys, desc=self.__class__.__name__)
244+
if display_progress
245+
else keys
246+
):
247+
status = self._populate1(key, jobs, **populate_kwargs)
248+
if status is True:
249+
success_list.append(1)
250+
elif isinstance(status, tuple):
251+
error_list.append(status)
252+
else:
253+
assert status is False
254+
else:
255+
# spawn multiple processes
256+
self.connection.close() # disconnect parent process from MySQL server
257+
del self.connection._conn.ctx # SSLContext is not pickleable
258+
with mp.Pool(
259+
processes, _initialize_populate, (self, jobs, populate_kwargs)
260+
) as pool, (
261+
tqdm(desc="Processes: ", total=nkeys)
262+
if display_progress
263+
else contextlib.nullcontext()
264+
) as progress_bar:
265+
for status in pool.imap(_call_populate1, keys, chunksize=1):
266+
if status is True:
267+
success_list.append(1)
268+
elif isinstance(status, tuple):
269+
error_list.append(status)
270+
else:
271+
assert status is False
272+
if display_progress:
273+
progress_bar.update()
274+
self.connection.connect() # reconnect parent process to MySQL server
261275

262276
# restore original signal handler:
263277
if reserve_jobs:
264278
signal.signal(signal.SIGTERM, old_handler)
265279

266-
if suppress_errors:
267-
return error_list
280+
return {
281+
"success_count": sum(success_list),
282+
"error_list": error_list,
283+
}
268284

269285
def _populate1(
270286
self, key, jobs, suppress_errors, return_exception_objects, make_kwargs=None
@@ -275,55 +291,60 @@ def _populate1(
275291
:param key: dict specifying job to populate
276292
:param suppress_errors: bool if errors should be suppressed and returned
277293
:param return_exception_objects: if True, errors must be returned as objects
278-
:return: (key, error) when suppress_errors=True, otherwise None
294+
:return: (key, error) if ``make()`` raised an error and suppress_errors=True,
295+
True if one ``make()`` call completed successfully, False if the key was skipped (job not reserved or key already populated)
279296
"""
280297
make = self._make_tuples if hasattr(self, "_make_tuples") else self.make
281298

282-
if jobs is None or jobs.reserve(self.target.table_name, self._job_key(key)):
283-
self.connection.start_transaction()
284-
if key in self.target: # already populated
299+
if jobs is not None and not jobs.reserve(
300+
self.target.table_name, self._job_key(key)
301+
):
302+
return False
303+
304+
self.connection.start_transaction()
305+
if key in self.target: # already populated
306+
self.connection.cancel_transaction()
307+
if jobs is not None:
308+
jobs.complete(self.target.table_name, self._job_key(key))
309+
return False
310+
311+
logger.debug(f"Making {key} -> {self.target.full_table_name}")
312+
self.__class__._allow_insert = True
313+
try:
314+
make(dict(key), **(make_kwargs or {}))
315+
except (KeyboardInterrupt, SystemExit, Exception) as error:
316+
try:
285317
self.connection.cancel_transaction()
286-
if jobs is not None:
287-
jobs.complete(self.target.table_name, self._job_key(key))
318+
except LostConnectionError:
319+
pass
320+
error_message = "{exception}{msg}".format(
321+
exception=error.__class__.__name__,
322+
msg=": " + str(error) if str(error) else "",
323+
)
324+
logger.debug(
325+
f"Error making {key} -> {self.target.full_table_name} - {error_message}"
326+
)
327+
if jobs is not None:
328+
# show error name and error message (if any)
329+
jobs.error(
330+
self.target.table_name,
331+
self._job_key(key),
332+
error_message=error_message,
333+
error_stack=traceback.format_exc(),
334+
)
335+
if not suppress_errors or isinstance(error, SystemExit):
336+
raise
288337
else:
289-
logger.debug(f"Making {key} -> {self.target.full_table_name}")
290-
self.__class__._allow_insert = True
291-
try:
292-
make(dict(key), **(make_kwargs or {}))
293-
except (KeyboardInterrupt, SystemExit, Exception) as error:
294-
try:
295-
self.connection.cancel_transaction()
296-
except LostConnectionError:
297-
pass
298-
error_message = "{exception}{msg}".format(
299-
exception=error.__class__.__name__,
300-
msg=": " + str(error) if str(error) else "",
301-
)
302-
logger.debug(
303-
f"Error making {key} -> {self.target.full_table_name} - {error_message}"
304-
)
305-
if jobs is not None:
306-
# show error name and error message (if any)
307-
jobs.error(
308-
self.target.table_name,
309-
self._job_key(key),
310-
error_message=error_message,
311-
error_stack=traceback.format_exc(),
312-
)
313-
if not suppress_errors or isinstance(error, SystemExit):
314-
raise
315-
else:
316-
logger.error(error)
317-
return key, error if return_exception_objects else error_message
318-
else:
319-
self.connection.commit_transaction()
320-
logger.debug(
321-
f"Success making {key} -> {self.target.full_table_name}"
322-
)
323-
if jobs is not None:
324-
jobs.complete(self.target.table_name, self._job_key(key))
325-
finally:
326-
self.__class__._allow_insert = False
338+
logger.error(error)
339+
return key, error if return_exception_objects else error_message
340+
else:
341+
self.connection.commit_transaction()
342+
logger.debug(f"Success making {key} -> {self.target.full_table_name}")
343+
if jobs is not None:
344+
jobs.complete(self.target.table_name, self._job_key(key))
345+
return True
346+
finally:
347+
self.__class__._allow_insert = False
327348

328349
def progress(self, *restrictions, display=False):
329350
"""

tests_old/test_autopopulate.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,23 @@ def test_populate(self):
5353
assert_true(self.ephys)
5454
assert_true(self.channel)
5555

56+
def test_populate_with_success_count(self):
57+
# test simple populate
58+
assert_true(self.subject, "root tables are empty")
59+
assert_false(self.experiment, "table already filled?")
60+
ret = self.experiment.populate()
61+
success_count = ret["success_count"]
62+
assert_equal(len(self.experiment.key_source & self.experiment), success_count)
63+
64+
# test restricted populate
65+
assert_false(self.trial, "table already filled?")
66+
restriction = self.subject.proj(animal="subject_id").fetch("KEY")[0]
67+
d = self.trial.connection.dependencies
68+
d.load()
69+
ret = self.trial.populate(restriction, suppress_errors=True)
70+
success_count = ret["success_count"]
71+
assert_equal(len(self.trial.key_source & self.trial), success_count)
72+
5673
def test_populate_exclude_error_and_ignore_jobs(self):
5774
# test simple populate
5875
assert_true(self.subject, "root tables are empty")

0 commit comments

Comments
 (0)