Skip to content

Commit 6526fa3

Browse files
committed
Raise new group exceptions on callback failure instead of just processing them
with the default exception hook. This is a breaking change to the API, but only in cases where the existing code would be completely broken anyway, so hopefully it isn't a problem. TODO: docs need updating
1 parent 3582925 commit 6526fa3

File tree

3 files changed

+132
-17
lines changed

3 files changed

+132
-17
lines changed

Lib/multiprocessing/managers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1232,7 +1232,7 @@ def __isub__(self, value):
12321232

12331233
BasePoolProxy = MakeProxyType('PoolProxy', (
12341234
'apply', 'apply_async', 'close', 'imap', 'imap_unordered', 'join',
1235-
'map', 'map_async', 'starmap', 'starmap_async', 'terminate',
1235+
'map', 'map_async', 'starmap', 'starmap_async', 'terminate', '_check_error'
12361236
))
12371237
BasePoolProxy._method_to_typeid_ = {
12381238
'apply_async': 'AsyncResult',
@@ -1246,6 +1246,7 @@ def __enter__(self):
12461246
return self
12471247
def __exit__(self, exc_type, exc_val, exc_tb):
12481248
self.terminate()
1249+
self._check_error(exc_val)
12491250

12501251
#
12511252
# Definition of SyncManager

Lib/multiprocessing/pool.py

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# Licensed to PSF under a Contributor Agreement.
88
#
99

10-
__all__ = ['Pool', 'ThreadPool']
10+
__all__ = ['BrokenPoolError', 'CallbackError', 'Pool', 'ThreadPool']
1111

1212
#
1313
# Imports
@@ -69,6 +69,14 @@ def __init__(self, exc, tb):
6969
def __reduce__(self):
7070
return rebuild_exc, (self.exc, self.tb)
7171

72+
class BrokenPoolError(ExceptionGroup):
73+
def __init__(self, msg, exc):
74+
super().__init__(msg, exc)
75+
76+
class CallbackError(ExceptionGroup):
77+
def __init__(self, msg, exc):
78+
super().__init__(msg, exc)
79+
7280
def rebuild_exc(exc, tb):
7381
exc.__cause__ = RemoteTraceback(tb)
7482
return exc
@@ -198,6 +206,7 @@ def __init__(self, processes=None, initializer=None, initargs=(),
198206
self._maxtasksperchild = maxtasksperchild
199207
self._initializer = initializer
200208
self._initargs = initargs
209+
self._errors = []
201210

202211
if processes is None:
203212
processes = os.process_cpu_count() or 1
@@ -349,9 +358,17 @@ def _setup_queues(self):
349358
self._quick_get = self._outqueue._reader.recv
350359

351360
def _check_running(self):
361+
self._check_error()
352362
if self._state != RUN:
353363
raise ValueError("Pool not running")
354364

365+
def _check_error(self, exc=None):
366+
if self._errors:
367+
errs = list(self._errors)
368+
if exc is not None and not isinstance(exc, CallbackError):
369+
errs.append(exc)
370+
raise BrokenPoolError("Callback(s) failed", errs) from None
371+
355372
def apply(self, func, args=(), kwds={}):
356373
'''
357374
Equivalent of `func(*args, **kwds)`.
@@ -737,6 +754,11 @@ def __enter__(self):
737754

738755
def __exit__(self, exc_type, exc_val, exc_tb):
739756
self.terminate()
757+
self._check_error(exc_val)
758+
759+
def _error(self, error):
760+
util.debug('callback error', exc_info=error)
761+
self._errors.append(error)
740762

741763
#
742764
# Class whose instances are returned by `Pool.apply_async()`
@@ -751,6 +773,7 @@ def __init__(self, pool, callback, error_callback):
751773
self._cache = pool._cache
752774
self._callback = callback
753775
self._error_callback = error_callback
776+
self._cb_error = None
754777
self._cache[self._job] = self
755778

756779
def ready(self):
@@ -768,30 +791,31 @@ def get(self, timeout=None):
768791
self.wait(timeout)
769792
if not self.ready():
770793
raise TimeoutError
771-
if self._success:
794+
if self._cb_error:
795+
raise self._cb_error
796+
elif self._success:
772797
return self._value
773798
else:
774799
raise self._value
775800

776801
def _set(self, i, obj):
777802
self._success, self._value = obj
778803
if self._callback and self._success:
779-
self._handle_exceptions(self._callback, self._value)
804+
try:
805+
self._callback(self._value)
806+
except Exception as e:
807+
self._cb_error = CallbackError("apply callback", [e])
808+
self._pool._error(self._cb_error)
780809
if self._error_callback and not self._success:
781-
self._handle_exceptions(self._error_callback, self._value)
810+
try:
811+
self._error_callback(self._value)
812+
except Exception as e:
813+
self._cb_error = CallbackError("apply error callback", [e, self._value])
814+
self._pool._error(self._cb_error)
782815
self._event.set()
783816
del self._cache[self._job]
784817
self._pool = None
785818

786-
@staticmethod
787-
def _handle_exceptions(callback, args):
788-
try:
789-
return callback(args)
790-
except Exception as e:
791-
args = threading.ExceptHookArgs([type(e), e, e.__traceback__, None])
792-
threading.excepthook(args)
793-
del args
794-
795819
__class_getitem__ = classmethod(types.GenericAlias)
796820

797821
AsyncResult = ApplyResult # create alias -- see #17805
@@ -822,7 +846,11 @@ def _set(self, i, success_result):
822846
self._value[i*self._chunksize:(i+1)*self._chunksize] = result
823847
if self._number_left == 0:
824848
if self._callback:
825-
self._handle_exceptions(self._callback, self._value)
849+
try:
850+
self._callback(self._value)
851+
except Exception as e:
852+
self._cb_error = CallbackError("map callback", [e])
853+
self._pool._error(self._cb_error)
826854
del self._cache[self._job]
827855
self._event.set()
828856
self._pool = None
@@ -834,7 +862,12 @@ def _set(self, i, success_result):
834862
if self._number_left == 0:
835863
# only consider the result ready once all jobs are done
836864
if self._error_callback:
837-
self._handle_exceptions(self._error_callback, self._value)
865+
try:
866+
self._error_callback(self._value)
867+
except Exception as e:
868+
self._cb_error = CallbackError("map error callback",
869+
[e, self._value])
870+
self._pool._error(self._cb_error)
838871
del self._cache[self._job]
839872
self._event.set()
840873
self._pool = None

Lib/test/_test_multiprocessing.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import multiprocessing.pool
5555
import multiprocessing.queues
5656
from multiprocessing.connection import wait
57+
from multiprocessing.pool import BrokenPoolError, CallbackError
5758

5859
from multiprocessing import util
5960

@@ -3083,9 +3084,89 @@ def test_resource_warning(self):
30833084
pool = None
30843085
support.gc_collect()
30853086

3086-
def raising():
3087+
def test_callback_errors(self):
3088+
def _apply(pool, target, **kwargs):
3089+
return pool.apply_async(target, **kwargs)
3090+
3091+
def _map(pool, target, **kwargs):
3092+
return pool.map_async(target, range(5), **kwargs)
3093+
3094+
for func in [_apply, _map]:
3095+
with self.subTest(func=func):
3096+
3097+
# Fail upon trying to reuse a broken pool after callback failure:
3098+
# - BrokenPoolError containing:
3099+
# - CallbackError containing:
3100+
# - Error thrown from the callback
3101+
with self.assertRaises(BrokenPoolError) as pool_ctx:
3102+
with self.Pool(1) as pool:
3103+
res = func(pool, noop, callback=raising)
3104+
with self.assertRaises(CallbackError) as res_ctx:
3105+
res.get()
3106+
self._check_subexceptions(res_ctx.exception, [KeyError])
3107+
pool.apply_async(noop)
3108+
self._check_subexceptions(pool_ctx.exception, [CallbackError])
3109+
self._check_subexceptions(pool_ctx.exception.exceptions[0], [KeyError])
3110+
3111+
# Fail upon trying to reuse a broken pool after error callback failures:
3112+
# - BrokenPoolError containing:
3113+
# - 3x CallbackError each containing:
3114+
# - Error thrown from the callback
3115+
# - Original error
3116+
with self.assertRaises(BrokenPoolError) as pool_ctx:
3117+
with self.Pool(3) as pool:
3118+
res = [func(pool, raising2, error_callback=raising)
3119+
for _ in range(3)]
3120+
for r in res:
3121+
with self.assertRaises(CallbackError) as res_ctx:
3122+
r.get()
3123+
self._check_subexceptions(res_ctx.exception,
3124+
[KeyError, IndexError])
3125+
pool.apply_async(noop)
3126+
self._check_subexceptions(pool_ctx.exception, [CallbackError] * 3)
3127+
for se in pool_ctx.exception.exceptions:
3128+
self._check_subexceptions(se, [KeyError, IndexError])
3129+
3130+
# Exiting the context manager with a "normal" error and a failed callback
3131+
# - BrokenPoolError containing:
3132+
# - CallbackError containing:
3133+
# - Error thrown from the callback
3134+
# - Exception that caused the context manager to exit
3135+
with self.assertRaises(BrokenPoolError) as pool_ctx:
3136+
with self.Pool(1) as pool:
3137+
res = func(pool, noop, callback=raising)
3138+
with self.assertRaises(CallbackError) as res_ctx:
3139+
res.get()
3140+
raise IndexError()
3141+
self._check_subexceptions(pool_ctx.exception,
3142+
[CallbackError, IndexError])
3143+
3144+
# Exiting the context manager directly with a callback failure error
3145+
# - BrokenPoolError containing:
3146+
# - CallbackError instance containing:
3147+
# - Error thrown from the callback
3148+
# Note that only one instance of the error is present: it was
3149+
# *not* added again as it was above, since it is a CallbackError
3150+
with self.assertRaises(BrokenPoolError) as pool_ctx:
3151+
with self.Pool(1) as pool:
3152+
func(pool, noop, callback=raising).get()
3153+
self._check_subexceptions(pool_ctx.exception, [CallbackError])
3154+
self._check_subexceptions(pool_ctx.exception.exceptions[0], [KeyError])
3155+
3156+
def _check_subexceptions(self, group, sub_types):
3157+
self.assertEqual(len(group.exceptions), len(sub_types))
3158+
for sub_exc, sub_type in zip(group.exceptions, sub_types):
3159+
self.assertIsInstance(sub_exc, sub_type)
3160+
3161+
def noop(*args):
3162+
pass
3163+
3164+
def raising(*args):
30873165
raise KeyError("key")
30883166

3167+
def raising2(*args):
3168+
raise IndexError()
3169+
30893170
def unpickleable_result():
30903171
return lambda: 42
30913172

0 commit comments

Comments
 (0)