Skip to content

Commit 893d7e3

Browse files
authored
Fix worker shutdown (#83)
If the worker process was killed while its output queue was being accessed, the queue would be left in a broken state, so msg_pump would never exit and .join() would never unblock. Also fixes a miscellaneous indentation bug in a test case.
1 parent 5eb4ead commit 893d7e3

File tree

2 files changed

+34
-31
lines changed

2 files changed

+34
-31
lines changed

compiler_opt/distributed/local/local_worker_manager.py

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,10 @@ class TaskResult:
5959
value: Any
6060

6161

62-
def _run_impl(in_q: 'queue.Queue[Task]', out_q: 'queue.Queue[TaskResult]',
62+
def _run_impl(pipe: multiprocessing.connection.Connection,
6363
worker_class: 'type[Worker]', *args, **kwargs):
6464
"""Worker process entrypoint."""
65-
# Note: the out_q is typed as taking only TaskResult objects, not
66-
# Optional[TaskResult], despite that being the type it is used on the Stub
67-
# side. This is because the `None` value is only injected by the Stub itself.
6865

69-
# `threads` is defaulted to 1 in LocalWorkerPool's constructor.
7066
# A setting of 1 does not inhibit the while loop below from running since
7167
# that runs on the main thread of the process. Urgent tasks will still
7268
# process near-immediately. `threads` only controls how many threads are
@@ -75,27 +71,34 @@ def _run_impl(in_q: 'queue.Queue[Task]', out_q: 'queue.Queue[TaskResult]',
7571
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
7672
obj = worker_class(*args, **kwargs)
7773

74+
# Pipes are not thread safe
75+
pipe_lock = threading.Lock()
76+
77+
def send(task_result: TaskResult):
78+
with pipe_lock:
79+
pipe.send(task_result)
80+
7881
def make_ondone(msgid):
7982

8083
def on_done(f: concurrent.futures.Future):
8184
if f.exception():
82-
out_q.put(TaskResult(msgid=msgid, success=False, value=f.exception()))
85+
send(TaskResult(msgid=msgid, success=False, value=f.exception()))
8386
else:
84-
out_q.put(TaskResult(msgid=msgid, success=True, value=f.result()))
87+
send(TaskResult(msgid=msgid, success=True, value=f.result()))
8588

8689
return on_done
8790

8891
# Run forever. The stub will just kill the runner when done.
8992
while True:
90-
task = in_q.get()
93+
task: Task = pipe.recv()
9194
the_func = getattr(obj, task.func_name)
9295
application = functools.partial(the_func, *task.args, **task.kwargs)
9396
if task.is_urgent:
9497
try:
9598
res = application()
96-
out_q.put(TaskResult(msgid=task.msgid, success=True, value=res))
99+
send(TaskResult(msgid=task.msgid, success=True, value=res))
97100
except BaseException as e: # pylint: disable=broad-except
98-
out_q.put(TaskResult(msgid=task.msgid, success=False, value=e))
101+
send(TaskResult(msgid=task.msgid, success=False, value=e))
99102
else:
100103
pool.submit(application).add_done_callback(make_ondone(task.msgid))
101104

@@ -114,33 +117,30 @@ class _Stub():
114117
"""Client stub to a worker hosted by a process."""
115118

116119
def __init__(self):
117-
self._send: 'queue.Queue[Task]' = multiprocessing.get_context().Queue()
118-
self._receive: 'queue.Queue[Optional[TaskResult]]' = \
119-
multiprocessing.get_context().Queue()
120+
parent_pipe, child_pipe = multiprocessing.get_context().Pipe()
121+
self._pipe = parent_pipe
122+
self._pipe_lock = threading.Lock()
120123

121124
# this is the process hosting one worker instance.
122125
# we set aside 1 thread to coordinate running jobs, and the main thread
123126
# to handle high priority requests. The expectation is that the user
124127
# achieves concurrency through multiprocessing, not multithreading.
125128
self._process = multiprocessing.Process(
126129
target=functools.partial(
127-
_run,
128-
worker_class=cls,
129-
in_q=self._send,
130-
out_q=self._receive,
131-
*args,
132-
**kwargs))
130+
_run, worker_class=cls, pipe=child_pipe, *args, **kwargs))
133131
# lock for the msgid -> reply future map. The map will be set to None
134132
# when we stop.
135133
self._lock = threading.Lock()
136134
self._map: Dict[int, concurrent.futures.Future] = {}
137135

138-
# thread drainig the receive queue
136+
# thread draining the pipe
139137
self._pump = threading.Thread(target=self._msg_pump)
140138

139+
# Set the state of this worker to "dead" if the process dies naturally.
141140
def observer():
142141
self._process.join()
143-
self._receive.put(None)
142+
# Feed the parent pipe a poison pill, this kills msg_pump
143+
child_pipe.send(None)
144144

145145
self._observer = threading.Thread(target=observer)
146146

@@ -156,8 +156,8 @@ def observer():
156156

157157
def _msg_pump(self):
158158
while True:
159-
task_result = self._receive.get()
160-
if task_result is None:
159+
task_result: Optional[TaskResult] = self._pipe.recv()
160+
if task_result is None: # Poison pill fed by observer
161161
break
162162
with self._lock:
163163
future = self._map[task_result.msgid]
@@ -189,20 +189,23 @@ def remote_call(*args, **kwargs):
189189
if self._is_stopped():
190190
result_future.set_exception(concurrent.futures.CancelledError())
191191
else:
192-
self._send.put(
193-
Task(
194-
msgid=msgid,
195-
func_name=name,
196-
args=args,
197-
kwargs=kwargs,
198-
is_urgent=cls.is_priority_method(name)))
192+
with self._pipe_lock:
193+
self._pipe.send(
194+
Task(
195+
msgid=msgid,
196+
func_name=name,
197+
args=args,
198+
kwargs=kwargs,
199+
is_urgent=cls.is_priority_method(name)))
199200
self._map[msgid] = result_future
200201
return result_future
201202

202203
return remote_call
203204

204205
def shutdown(self):
205206
try:
207+
# Killing the process triggers observer exit, which triggers msg_pump
208+
# exit
206209
self._process.kill()
207210
except: # pylint: disable=bare-except
208211
pass

compiler_opt/rl/local_data_collector_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def _test_iterator_fn(data_list):
159159
**expected_monitor_dict_subset
160160
})
161161

162-
collector.close_pool()
162+
collector.close_pool()
163163

164164
def test_local_data_collector_task_management(self):
165165

0 commit comments

Comments
 (0)