Skip to content

Commit 63dd89b

Browse files
Make ThreadPoolExecutor extensible.
1 parent ee1b8ce commit 63dd89b

File tree

1 file changed

+58
-26
lines changed

1 file changed

+58
-26
lines changed

Lib/concurrent/futures/thread.py

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,46 @@ def _python_exit():
4343
after_in_parent=_global_shutdown_lock.release)
4444

4545

46+
class WorkerContext:
47+
48+
@classmethod
49+
def prepare(cls, initializer, initargs):
50+
if initializer is not None:
51+
if not callable(initializer):
52+
raise TypeError("initializer must be a callable")
53+
def create_context():
54+
return cls(initializer, initargs)
55+
def resolve_task(cls, fn, args, kwargs):
56+
return (fn, args, kwargs)
57+
return create_context, resolve_task
58+
59+
def __init__(self, initializer, initargs):
60+
self.initializer = initializer
61+
self.initargs = initargs
62+
63+
def initialize(self):
64+
if self.initializer is not None:
65+
self.initializer(*self.initargs)
66+
67+
def finalize(self):
68+
pass
69+
70+
def run(self, task):
71+
fn, args, kwargs = task
72+
return fn(*args, **kwargs)
73+
74+
4675
class _WorkItem:
47-
def __init__(self, future, fn, args, kwargs):
76+
def __init__(self, future, task):
4877
self.future = future
49-
self.fn = fn
50-
self.args = args
51-
self.kwargs = kwargs
78+
self.task = task
5279

53-
def run(self):
80+
def run(self, ctx):
5481
if not self.future.set_running_or_notify_cancel():
5582
return
5683

5784
try:
58-
result = self.fn(*self.args, **self.kwargs)
85+
result = ctx.run(self.task)
5986
except BaseException as exc:
6087
self.future.set_exception(exc)
6188
# Break a reference cycle with the exception 'exc'
@@ -66,16 +93,15 @@ def run(self):
6693
__class_getitem__ = classmethod(types.GenericAlias)
6794

6895

69-
def _worker(executor_reference, work_queue, initializer, initargs):
70-
if initializer is not None:
71-
try:
72-
initializer(*initargs)
73-
except BaseException:
74-
_base.LOGGER.critical('Exception in initializer:', exc_info=True)
75-
executor = executor_reference()
76-
if executor is not None:
77-
executor._initializer_failed()
78-
return
96+
def _worker(executor_reference, ctx, work_queue):
97+
try:
98+
ctx.initialize()
99+
except BaseException:
100+
_base.LOGGER.critical('Exception in initializer:', exc_info=True)
101+
executor = executor_reference()
102+
if executor is not None:
103+
executor._initializer_failed()
104+
return
79105
try:
80106
while True:
81107
try:
@@ -89,7 +115,7 @@ def _worker(executor_reference, work_queue, initializer, initargs):
89115
work_item = work_queue.get(block=True)
90116

91117
if work_item is not None:
92-
work_item.run()
118+
work_item.run(ctx)
93119
# Delete references to object. See GH-60488
94120
del work_item
95121
continue
@@ -110,6 +136,8 @@ def _worker(executor_reference, work_queue, initializer, initargs):
110136
del executor
111137
except BaseException:
112138
_base.LOGGER.critical('Exception in worker', exc_info=True)
139+
finally:
140+
ctx.finalize()
113141

114142

115143
class BrokenThreadPool(_base.BrokenExecutor):
@@ -123,8 +151,12 @@ class ThreadPoolExecutor(_base.Executor):
123151
# Used to assign unique thread names when thread_name_prefix is not supplied.
124152
_counter = itertools.count().__next__
125153

154+
@classmethod
155+
def prepare_context(cls, initializer, initargs):
156+
return WorkerContext.prepare(initializer, initargs)
157+
126158
def __init__(self, max_workers=None, thread_name_prefix='',
127-
initializer=None, initargs=()):
159+
initializer=None, initargs=(), **ctxkwargs):
128160
"""Initializes a new ThreadPoolExecutor instance.
129161
130162
Args:
@@ -133,6 +165,7 @@ def __init__(self, max_workers=None, thread_name_prefix='',
133165
thread_name_prefix: An optional name prefix to give our threads.
134166
initializer: A callable used to initialize worker threads.
135167
initargs: A tuple of arguments to pass to the initializer.
168+
ctxkwargs: Additional arguments to cls.prepare_context().
136169
"""
137170
if max_workers is None:
138171
# ThreadPoolExecutor is often used to:
@@ -146,8 +179,9 @@ def __init__(self, max_workers=None, thread_name_prefix='',
146179
if max_workers <= 0:
147180
raise ValueError("max_workers must be greater than 0")
148181

149-
if initializer is not None and not callable(initializer):
150-
raise TypeError("initializer must be a callable")
182+
(self._create_worker_context,
183+
self._resolve_work_item_task,
184+
) = type(self).prepare_context(initializer, initargs, **ctxkwargs)
151185

152186
self._max_workers = max_workers
153187
self._work_queue = queue.SimpleQueue()
@@ -158,8 +192,6 @@ def __init__(self, max_workers=None, thread_name_prefix='',
158192
self._shutdown_lock = threading.Lock()
159193
self._thread_name_prefix = (thread_name_prefix or
160194
("ThreadPoolExecutor-%d" % self._counter()))
161-
self._initializer = initializer
162-
self._initargs = initargs
163195

164196
def submit(self, fn, /, *args, **kwargs):
165197
with self._shutdown_lock, _global_shutdown_lock:
@@ -173,7 +205,8 @@ def submit(self, fn, /, *args, **kwargs):
173205
'interpreter shutdown')
174206

175207
f = _base.Future()
176-
w = _WorkItem(f, fn, args, kwargs)
208+
task = self._resolve_work_item_task(f, fn, args, kwargs)
209+
w = _WorkItem(f, task)
177210

178211
self._work_queue.put(w)
179212
self._adjust_thread_count()
@@ -196,9 +229,8 @@ def weakref_cb(_, q=self._work_queue):
196229
num_threads)
197230
t = threading.Thread(name=thread_name, target=_worker,
198231
args=(weakref.ref(self, weakref_cb),
199-
self._work_queue,
200-
self._initializer,
201-
self._initargs))
232+
self._create_worker_context(),
233+
self._work_queue))
202234
t.start()
203235
self._threads.add(t)
204236
_threads_queues[t] = self._work_queue

0 commit comments

Comments
 (0)