Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions Lib/multiprocessing/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,13 +776,23 @@ def get(self, timeout=None):
def _set(self, i, obj):
self._success, self._value = obj
if self._callback and self._success:
self._callback(self._value)
self._handle_exceptions(self._callback, self._value)
if self._error_callback and not self._success:
self._error_callback(self._value)
self._handle_exceptions(self._error_callback, self._value)
self._event.set()
del self._cache[self._job]
self._pool = None

@staticmethod
def _handle_exceptions(callback, args):
try:
return callback(args)
except Exception as e:
args = threading.ExceptHookArgs([type(e), e, e.__traceback__,
threading.current_thread()])
threading.excepthook(args)
del args

__class_getitem__ = classmethod(types.GenericAlias)

AsyncResult = ApplyResult # create alias -- see #17805
Expand Down Expand Up @@ -813,7 +823,7 @@ def _set(self, i, success_result):
self._value[i*self._chunksize:(i+1)*self._chunksize] = result
if self._number_left == 0:
if self._callback:
self._callback(self._value)
self._handle_exceptions(self._callback, self._value)
del self._cache[self._job]
self._event.set()
self._pool = None
Expand All @@ -825,7 +835,7 @@ def _set(self, i, success_result):
if self._number_left == 0:
# only consider the result ready once all jobs are done
if self._error_callback:
self._error_callback(self._value)
self._handle_exceptions(self._error_callback, self._value)
del self._cache[self._job]
self._event.set()
self._pool = None
Expand Down
34 changes: 33 additions & 1 deletion Lib/test/_test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3172,7 +3172,39 @@ def test_resource_warning(self):
pool = None
support.gc_collect()

def raising():
def test_callback_errors(self):
if self.TYPE == 'manager':
self.skipTest("cannot intercept excepthook in manager")

def _apply(pool, target, **kwargs):
return pool.apply_async(target, **kwargs)

def _map(pool, target, **kwargs):
return pool.map_async(target, range(1), **kwargs)

def record_exceptions(errs):
def record(args):
errs.append(args.exc_type)
return record

errs = []
for func in [_apply, _map]:
with self.subTest(func=func):
saved_hook = threading.excepthook
threading.excepthook = record_exceptions(errs)
try:
with self.Pool(1) as pool:
res = func(pool, noop, callback=raising)
res.get()
finally:
threading.excepthook = saved_hook

self.assertEqual(errs, [KeyError, KeyError])

def noop(*args):
pass

def raising(*args):
raise KeyError("key")

def unpickleable_result():
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Handle exceptions thrown by callbacks passed to
:class:`multiprocessing.pool.Pool` ``*_async`` methods, preventing them from
breaking the pool.
Loading