77# Licensed to PSF under a Contributor Agreement.
88#
99
10- __all__ = ['Pool' , 'ThreadPool' ]
10+ __all__ = ['BrokenPoolError' , 'CallbackError' , ' Pool' , 'ThreadPool' ]
1111
1212#
1313# Imports
@@ -69,6 +69,14 @@ def __init__(self, exc, tb):
6969 def __reduce__ (self ):
7070 return rebuild_exc , (self .exc , self .tb )
7171
72+ class BrokenPoolError (ExceptionGroup ):
73+ def __init__ (self , msg , exc ):
74+ super ().__init__ (msg , exc )
75+
76+ class CallbackError (ExceptionGroup ):
77+ def __init__ (self , msg , exc ):
78+ super ().__init__ (msg , exc )
79+
7280def rebuild_exc (exc , tb ):
7381 exc .__cause__ = RemoteTraceback (tb )
7482 return exc
@@ -198,6 +206,7 @@ def __init__(self, processes=None, initializer=None, initargs=(),
198206 self ._maxtasksperchild = maxtasksperchild
199207 self ._initializer = initializer
200208 self ._initargs = initargs
209+ self ._errors = []
201210
202211 if processes is None :
203212 processes = os .process_cpu_count () or 1
@@ -349,9 +358,17 @@ def _setup_queues(self):
349358 self ._quick_get = self ._outqueue ._reader .recv
350359
351360 def _check_running (self ):
361+ self ._check_error ()
352362 if self ._state != RUN :
353363 raise ValueError ("Pool not running" )
354364
365+ def _check_error (self , exc = None ):
366+ if self ._errors :
367+ errs = list (self ._errors )
368+ if exc is not None and not isinstance (exc , CallbackError ):
369+ errs .append (exc )
370+ raise BrokenPoolError ("Callback(s) failed" , errs ) from None
371+
355372 def apply (self , func , args = (), kwds = {}):
356373 '''
357374 Equivalent of `func(*args, **kwds)`.
@@ -737,6 +754,11 @@ def __enter__(self):
737754
738755 def __exit__ (self , exc_type , exc_val , exc_tb ):
739756 self .terminate ()
757+ self ._check_error (exc_val )
758+
759+ def _error (self , error ):
760+ util .debug ('callback error' , exc_info = error )
761+ self ._errors .append (error )
740762
741763#
742764# Class whose instances are returned by `Pool.apply_async()`
@@ -751,6 +773,7 @@ def __init__(self, pool, callback, error_callback):
751773 self ._cache = pool ._cache
752774 self ._callback = callback
753775 self ._error_callback = error_callback
776+ self ._cb_error = None
754777 self ._cache [self ._job ] = self
755778
756779 def ready (self ):
@@ -768,30 +791,31 @@ def get(self, timeout=None):
768791 self .wait (timeout )
769792 if not self .ready ():
770793 raise TimeoutError
771- if self ._success :
794+ if self ._cb_error :
795+ raise self ._cb_error
796+ elif self ._success :
772797 return self ._value
773798 else :
774799 raise self ._value
775800
776801 def _set (self , i , obj ):
777802 self ._success , self ._value = obj
778803 if self ._callback and self ._success :
779- self ._handle_exceptions (self ._callback , self ._value )
804+ try :
805+ self ._callback (self ._value )
806+ except Exception as e :
807+ self ._cb_error = CallbackError ("apply callback" , [e ])
808+ self ._pool ._error (self ._cb_error )
780809 if self ._error_callback and not self ._success :
781- self ._handle_exceptions (self ._error_callback , self ._value )
810+ try :
811+ self ._error_callback (self ._value )
812+ except Exception as e :
813+ self ._cb_error = CallbackError ("apply error callback" , [e , self ._value ])
814+ self ._pool ._error (self ._cb_error )
782815 self ._event .set ()
783816 del self ._cache [self ._job ]
784817 self ._pool = None
785818
786- @staticmethod
787- def _handle_exceptions (callback , args ):
788- try :
789- return callback (args )
790- except Exception as e :
791- args = threading .ExceptHookArgs ([type (e ), e , e .__traceback__ , None ])
792- threading .excepthook (args )
793- del args
794-
795819 __class_getitem__ = classmethod (types .GenericAlias )
796820
797821AsyncResult = ApplyResult # create alias -- see #17805
@@ -822,7 +846,11 @@ def _set(self, i, success_result):
822846 self ._value [i * self ._chunksize :(i + 1 )* self ._chunksize ] = result
823847 if self ._number_left == 0 :
824848 if self ._callback :
825- self ._handle_exceptions (self ._callback , self ._value )
849+ try :
850+ self ._callback (self ._value )
851+ except Exception as e :
852+ self ._cb_error = CallbackError ("map callback" , [e ])
853+ self ._pool ._error (self ._cb_error )
826854 del self ._cache [self ._job ]
827855 self ._event .set ()
828856 self ._pool = None
@@ -834,7 +862,12 @@ def _set(self, i, success_result):
834862 if self ._number_left == 0 :
835863 # only consider the result ready once all jobs are done
836864 if self ._error_callback :
837- self ._handle_exceptions (self ._error_callback , self ._value )
865+ try :
866+ self ._error_callback (self ._value )
867+ except Exception as e :
868+ self ._cb_error = CallbackError ("map error callback" ,
869+ [e , self ._value ])
870+ self ._pool ._error (self ._cb_error )
838871 del self ._cache [self ._job ]
839872 self ._event .set ()
840873 self ._pool = None
0 commit comments