@@ -43,19 +43,46 @@ def _python_exit():
4343                        after_in_parent = _global_shutdown_lock .release )
4444
4545
46+ class  WorkerContext :
47+ 
48+     @classmethod  
49+     def  prepare (cls , initializer , initargs ):
50+         if  initializer  is  not   None :
51+             if  not  callable (initializer ):
52+                 raise  TypeError ("initializer must be a callable" )
53+         def  create_context ():
54+             return  cls (initializer , initargs )
55+         def  resolve_task (cls , fn , args , kwargs ):
56+             return  (fn , args , kwargs )
57+         return  create_context , resolve_task 
58+ 
59+     def  __init__ (self , initializer , initargs ):
60+         self .initializer  =  initializer 
61+         self .initargs  =  initargs 
62+ 
63+     def  initialize (self ):
64+         if  self .initializer  is  not   None :
65+             self .initializer (* self .initargs )
66+ 
67+     def  finalize (self ):
68+         pass 
69+ 
70+     def  run (self , task ):
71+         fn , args , kwargs  =  task 
72+         return  fn (* args , ** kwargs )
73+ 
74+ 
4675class  _WorkItem :
47-     def  __init__ (self , future , fn ,  args ,  kwargs ):
76+     def  __init__ (self , future , task ):
4877        self .future  =  future 
49-         self .fn  =  fn 
50-         self .args  =  args 
51-         self .kwargs  =  kwargs 
78+         self .task  =  task 
5279
53-     def  run (self ):
80+     def  run (self ,  ctx ):
5481        if  not  self .future .set_running_or_notify_cancel ():
5582            return 
5683
5784        try :
58-             result  =  self . fn ( * self .args ,  ** self . kwargs )
85+             result  =  ctx . run ( self .task )
5986        except  BaseException  as  exc :
6087            self .future .set_exception (exc )
6188            # Break a reference cycle with the exception 'exc' 
@@ -66,16 +93,15 @@ def run(self):
6693    __class_getitem__  =  classmethod (types .GenericAlias )
6794
6895
69- def  _worker (executor_reference , work_queue , initializer , initargs ):
70-     if  initializer  is  not   None :
71-         try :
72-             initializer (* initargs )
73-         except  BaseException :
74-             _base .LOGGER .critical ('Exception in initializer:' , exc_info = True )
75-             executor  =  executor_reference ()
76-             if  executor  is  not   None :
77-                 executor ._initializer_failed ()
78-             return 
96+ def  _worker (executor_reference , ctx , work_queue ):
97+     try :
98+         ctx .initialize ()
99+     except  BaseException :
100+         _base .LOGGER .critical ('Exception in initializer:' , exc_info = True )
101+         executor  =  executor_reference ()
102+         if  executor  is  not   None :
103+             executor ._initializer_failed ()
104+         return 
79105    try :
80106        while  True :
81107            try :
@@ -89,7 +115,7 @@ def _worker(executor_reference, work_queue, initializer, initargs):
89115                work_item  =  work_queue .get (block = True )
90116
91117            if  work_item  is  not   None :
92-                 work_item .run ()
118+                 work_item .run (ctx )
93119                # Delete references to object. See GH-60488 
94120                del  work_item 
95121                continue 
@@ -110,6 +136,8 @@ def _worker(executor_reference, work_queue, initializer, initargs):
110136            del  executor 
111137    except  BaseException :
112138        _base .LOGGER .critical ('Exception in worker' , exc_info = True )
139+     finally :
140+         ctx .finalize ()
113141
114142
115143class  BrokenThreadPool (_base .BrokenExecutor ):
@@ -123,8 +151,12 @@ class ThreadPoolExecutor(_base.Executor):
123151    # Used to assign unique thread names when thread_name_prefix is not supplied. 
124152    _counter  =  itertools .count ().__next__ 
125153
154+     @classmethod  
155+     def  prepare_context (cls , initializer , initargs ):
156+         return  WorkerContext .prepare (initializer , initargs )
157+ 
126158    def  __init__ (self , max_workers = None , thread_name_prefix = '' ,
127-                  initializer = None , initargs = ()):
159+                  initializer = None , initargs = (),  ** ctxkwargs ):
128160        """Initializes a new ThreadPoolExecutor instance. 
129161
130162        Args: 
@@ -133,6 +165,7 @@ def __init__(self, max_workers=None, thread_name_prefix='',
133165            thread_name_prefix: An optional name prefix to give our threads. 
134166            initializer: A callable used to initialize worker threads. 
135167            initargs: A tuple of arguments to pass to the initializer. 
168+             ctxkwargs: Additional arguments to cls.prepare_context(). 
136169        """ 
137170        if  max_workers  is  None :
138171            # ThreadPoolExecutor is often used to: 
@@ -146,8 +179,9 @@ def __init__(self, max_workers=None, thread_name_prefix='',
146179        if  max_workers  <=  0 :
147180            raise  ValueError ("max_workers must be greater than 0" )
148181
149-         if  initializer  is  not   None  and  not  callable (initializer ):
150-             raise  TypeError ("initializer must be a callable" )
182+         (self ._create_worker_context ,
183+          self ._resolve_work_item_task ,
184+          ) =  type (self ).prepare_context (initializer , initargs , ** ctxkwargs )
151185
152186        self ._max_workers  =  max_workers 
153187        self ._work_queue  =  queue .SimpleQueue ()
@@ -158,8 +192,6 @@ def __init__(self, max_workers=None, thread_name_prefix='',
158192        self ._shutdown_lock  =  threading .Lock ()
159193        self ._thread_name_prefix  =  (thread_name_prefix  or 
160194                                    ("ThreadPoolExecutor-%d"  %  self ._counter ()))
161-         self ._initializer  =  initializer 
162-         self ._initargs  =  initargs 
163195
164196    def  submit (self , fn , / , * args , ** kwargs ):
165197        with  self ._shutdown_lock , _global_shutdown_lock :
@@ -173,7 +205,8 @@ def submit(self, fn, /, *args, **kwargs):
173205                                   'interpreter shutdown' )
174206
175207            f  =  _base .Future ()
176-             w  =  _WorkItem (f , fn , args , kwargs )
208+             task  =  self ._resolve_work_item_task (f , fn , args , kwargs )
209+             w  =  _WorkItem (f , task )
177210
178211            self ._work_queue .put (w )
179212            self ._adjust_thread_count ()
@@ -196,9 +229,8 @@ def weakref_cb(_, q=self._work_queue):
196229                                     num_threads )
197230            t  =  threading .Thread (name = thread_name , target = _worker ,
198231                                 args = (weakref .ref (self , weakref_cb ),
199-                                        self ._work_queue ,
200-                                        self ._initializer ,
201-                                        self ._initargs ))
232+                                        self ._create_worker_context (),
233+                                        self ._work_queue ))
202234            t .start ()
203235            self ._threads .add (t )
204236            _threads_queues [t ] =  self ._work_queue 
0 commit comments