Skip to content

Commit 7630dea

Browse files
authored
[infra] Introduce a simple local-process worker pool manager (#33)
[infra] Introduce a simple local-process worker pool manager. This is meant as an in-place replacement of the current functionality, but using a stateful worker object abstraction analogous to what Dask supports. For local workloads, this implementation is faster than Dask's LocalCluster (which may be due to misconfiguration). Because it's very simple and introduces no new dependencies, even if LocalCluster performance could be improved, the implementation would still be useful for debugging. The worker abstraction allows implementers to specify a list of methods that should be executed promptly on the server side, analogous to Dask's `separate_thread=False` concept. For our purposes, we can use it to implement full cancellation of work on the server side. This patch introduces the worker manager; a subsequent patch will enable its use in the rest of the codebase.
1 parent 5f6b5f8 commit 7630dea

File tree

5 files changed

+384
-0
lines changed

5 files changed

+384
-0
lines changed

compiler_opt/distributed/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Local Process Pool - based middleware implementation.
16+
17+
This is a simple implementation of a worker pool, running on the local machine.
18+
Each worker object is hosted by a separate process. Each worker object may
19+
handle a number of concurrent requests. The client is given a stub object that
20+
exposes the same methods as the worker, just that they return Futures.
21+
22+
There is a pair of queues between a stub and its corresponding process/worker.
23+
One queue is used to place tasks (method calls), the other to receive results.
24+
Tasks and results are correlated by a monotonically incrementing counter
25+
maintained by the stub.
26+
27+
The worker process dequeues tasks promptly and either re-enqueues them to a
28+
local thread pool, or, if the task is 'urgent', it executes it promptly.
29+
"""
30+
import concurrent.futures
31+
import dataclasses
32+
import functools
33+
import multiprocessing
34+
import multiprocessing.connection
35+
import queue # pylint: disable=unused-import
36+
import threading
37+
38+
from absl import logging
39+
# pylint: disable=unused-import
40+
from compiler_opt.distributed.worker import Worker
41+
42+
from contextlib import AbstractContextManager
43+
from typing import Any, Callable, Dict, Optional
44+
45+
46+
@dataclasses.dataclass(frozen=True)
class Task:
  """A single method-call request sent from a stub to its worker process."""
  # Correlation id; the TaskResult produced for this call carries the same id.
  msgid: int
  # Name of the method to look up on the worker object.
  func_name: str
  # Positional arguments for the call.
  args: tuple
  # Keyword arguments for the call.
  kwargs: dict
  # True if the worker process should run the call inline, as soon as it is
  # dequeued, rather than hand it to its thread pool.
  is_urgent: bool
53+
54+
55+
@dataclasses.dataclass(frozen=True)
class TaskResult:
  """The outcome of one Task, sent from the worker process back to the stub."""
  # Id of the Task this result answers.
  msgid: int
  # True if the call returned normally, False if it raised.
  success: bool
  # The call's return value when `success` is True, otherwise the exception
  # that was raised.
  value: Any
60+
61+
62+
def _run_impl(in_q: 'queue.Queue[Task]', out_q: 'queue.Queue[TaskResult]',
              worker_class: 'type[Worker]', *args, **kwargs):
  """Worker process entrypoint.

  Creates one `worker_class` instance and then serves tasks forever. Every
  task read from `in_q` produces exactly one TaskResult on `out_q`, carrying
  the task's msgid. Urgent tasks run inline in the dequeuing loop; all other
  tasks are submitted to a local thread pool.

  A task naming an attribute that does not exist on the worker, or that is not
  callable, is answered with a failed TaskResult. It does not terminate the
  loop, so one bad call cannot take down the worker and the other calls it is
  serving.

  Args:
    in_q: queue from which Task objects are read.
    out_q: queue on which TaskResult objects are placed.
    worker_class: the class of the worker object hosted by this process.
    *args: positional arguments for the `worker_class` constructor.
    **kwargs: keyword arguments for the `worker_class` constructor.
  """
  # Note: the out_q is typed as taking only TaskResult objects, not
  # Optional[TaskResult], despite that being the type it is used on the Stub
  # side. This is because the `None` value is only injected by the Stub itself.
  pool = concurrent.futures.ThreadPoolExecutor()
  obj = worker_class(*args, **kwargs)

  def reply(msgid, success, value):
    out_q.put(TaskResult(msgid=msgid, success=success, value=value))

  def make_ondone(msgid):

    def on_done(f: concurrent.futures.Future):
      # Compare against None rather than relying on truthiness: an exception
      # type is free to define __bool__ / __len__.
      error = f.exception()
      if error is not None:
        reply(msgid, False, error)
      else:
        reply(msgid, True, f.result())

    return on_done

  # Run forever. The stub will just kill the runner when done.
  while True:
    task = in_q.get()
    try:
      the_func = getattr(obj, task.func_name)
      application = functools.partial(the_func, *task.args, **task.kwargs)
    except (AttributeError, TypeError) as e:
      # Unknown attribute (AttributeError) or a non-callable one (TypeError
      # from functools.partial): fail this call only and keep serving.
      reply(task.msgid, False, e)
      continue
    if task.is_urgent:
      try:
        res = application()
        reply(task.msgid, True, res)
      except BaseException as e:  # pylint: disable=broad-except
        reply(task.msgid, False, e)
    else:
      pool.submit(application).add_done_callback(make_ondone(task.msgid))
94+
95+
96+
def _run(*args, **kwargs):
  """Process target: runs `_run_impl`, logging anything that escapes it.

  All arguments are forwarded unchanged to `_run_impl`. An exception that
  escapes it is logged together with its traceback and then re-raised, so the
  hosting process still terminates because of it.
  """
  try:
    _run_impl(*args, **kwargs)
  except BaseException:  # pylint: disable=broad-except
    # logging.exception records the exception type and traceback; this log
    # line is the only diagnostic left behind by a crashed worker process.
    logging.exception('Worker process terminated by an unhandled exception.')
    raise
102+
103+
104+
def _make_stub(cls: 'type[Worker]', *args, **kwargs):
  """Starts a process hosting a `cls` instance and returns a stub for it.

  Args:
    cls: the worker class to instantiate in the new process.
    *args: positional arguments for the `cls` constructor.
    **kwargs: keyword arguments for the `cls` constructor.

  Returns:
    A stub object. Any attribute not defined on the stub itself resolves to a
    callable that enqueues a call to the same-named worker method and returns
    a `concurrent.futures.Future` for its outcome.
  """

  class _Stub():
    """Client stub to a worker hosted by a process."""

    def __init__(self):
      self._send: 'queue.Queue[Task]' = multiprocessing.get_context().Queue()
      self._receive: 'queue.Queue[Optional[TaskResult]]' = (
          multiprocessing.get_context().Queue())

      # this is the process hosting one worker instance.
      self._process = multiprocessing.Process(
          target=functools.partial(
              _run,
              *args,
              worker_class=cls,
              in_q=self._send,
              out_q=self._receive,
              **kwargs))
      # lock for the msgid -> reply future map. The map will be set to None
      # when we stop.
      self._lock = threading.Lock()
      self._map: Optional[Dict[int, concurrent.futures.Future]] = {}

      # thread draining the receive queue
      self._pump = threading.Thread(target=self._msg_pump)

      def observer():
        # Once the worker process is gone, wake the pump up with the `None`
        # sentinel so that it can fail the pending futures and exit.
        self._process.join()
        self._receive.put(None)

      self._observer = threading.Thread(target=observer)

      # atomic control to _msgid
      self._msgidlock = threading.Lock()
      self._msgid = 0

      # start the worker and the message pump
      self._process.start()
      # the observer must follow the process start, otherwise join() raises.
      self._observer.start()
      self._pump.start()

    def _msg_pump(self):
      """Completes reply futures from the receive queue until `None` arrives."""
      while True:
        task_result = self._receive.get()
        if task_result is None:
          break
        with self._lock:
          future = self._map.pop(task_result.msgid)
          if task_result.success:
            future.set_result(task_result.value)
          else:
            future.set_exception(task_result.value)

      # clear out pending futures and mark ourselves as "stopped" by null-ing
      # the map
      with self._lock:
        for pending in self._map.values():
          pending.set_exception(concurrent.futures.CancelledError())
        self._map = None

    def _is_stopped(self):
      return self._map is None

    def __getattr__(self, name) -> Callable[..., concurrent.futures.Future]:
      """Returns a callable that invokes worker method `name` remotely.

      The message id and the reply future are created on every call of the
      returned callable, not once per attribute lookup. A saved bound method
      (`m = stub.method; m(); m()`) therefore gets a distinct id and a
      distinct future for each invocation.
      """

      def remote_call(*args, **kwargs):
        result_future = concurrent.futures.Future()
        with self._msgidlock:
          msgid = self._msgid
          self._msgid += 1

        with self._lock:
          if self._is_stopped():
            result_future.set_exception(concurrent.futures.CancelledError())
          else:
            self._send.put(
                Task(
                    msgid=msgid,
                    func_name=name,
                    args=args,
                    kwargs=kwargs,
                    is_urgent=cls.is_priority_method(name)))
            self._map[msgid] = result_future
        return result_future

      return remote_call

    def shutdown(self):
      """Kills the worker process; errors from the kill itself are ignored."""
      try:
        self._process.kill()
      except Exception:  # pylint: disable=broad-except
        # Best effort: failing to kill must not prevent shutting down the
        # remaining stubs. KeyboardInterrupt / SystemExit still propagate.
        pass

    def join(self):
      """Waits for the observer thread, the pump thread and the process."""
      self._observer.join()
      self._pump.join()
      self._process.join()

    def __dir__(self):
      return [n for n in dir(cls) if not n.startswith('_')]

  return _Stub()
207+
208+
209+
class LocalWorkerPool(AbstractContextManager):
  """Context manager owning worker stubs, each backed by its own local process.

  Entering the context yields the list of stubs. Leaving it shuts every stub
  down and then joins every stub.
  """

  def __init__(self, worker_class: 'type[Worker]', count: Optional[int], *args,
               **kwargs):
    """Creates the stubs.

    Args:
      worker_class: the worker class each stub is created for.
      count: number of stubs to create; a falsy value (None or 0) means
        `multiprocessing.cpu_count()`.
      *args: positional arguments passed on to each stub's creation.
      **kwargs: keyword arguments passed on to each stub's creation.
    """
    worker_count = count or multiprocessing.cpu_count()
    stubs = [
        _make_stub(worker_class, *args, **kwargs) for _ in range(worker_count)
    ]
    self._stubs = stubs

  def __enter__(self):
    return self._stubs

  def __exit__(self, *args, **kwargs):
    # Two separate passes: request shutdown of every stub first, and only
    # then wait on each of them.
    for stub in self._stubs:
      stub.shutdown()
    for stub in self._stubs:
      stub.join()
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Test for local worker manager."""
16+
17+
import concurrent.futures
18+
import multiprocessing
19+
import time
20+
21+
from absl.testing import absltest
22+
from compiler_opt.distributed.worker import Worker
23+
from compiler_opt.distributed.local import local_worker_manager
24+
from tf_agents.system import system_multiprocessing as multiprocessing
25+
26+
27+
class LocalWorkerManagerTest(absltest.TestCase):
  """Tests LocalWorkerPool: normal calls, a failing ctor, a killed worker."""

  def test_pool(self):
    """Each worker keeps its own state; priority and regular calls both work."""

    class Job(Worker):
      """Test worker."""

      def __init__(self):
        self._token = 0

      @classmethod
      def is_priority_method(cls, method_name: str) -> bool:
        return method_name == 'priority_method'

      def priority_method(self):
        return f'priority {self._token}'

      def get_token(self):
        return self._token

      def set_token(self, value):
        self._token = value

    with local_worker_manager.LocalWorkerPool(Job, 2) as pool:
      first, second = pool[0], pool[1]
      pending_sets = [first.set_token(1), second.set_token(2)]
      done, not_done = concurrent.futures.wait(pending_sets)
      self.assertLen(done, 2)
      self.assertEmpty(not_done)
      succeeded = [f for f in done if not f.exception()]
      self.assertLen(succeeded, 2)
      self.assertEqual(first.get_token().result(), 1)
      self.assertEqual(second.get_token().result(), 2)
      self.assertEqual(first.priority_method().result(), 'priority 1')
      self.assertEqual(second.priority_method().result(), 'priority 2')
      # wait - to make sure the pump doesn't panic if there's no new messages
      time.sleep(3)
      # everything still works
      self.assertEqual(second.get_token().result(), 2)

  def test_failure(self):
    """A worker whose constructor fails cancels calls made to it."""

    class Job(Worker):

      def __init__(self, wont_be_passed):
        self._arg = wont_be_passed

      def method(self):
        return self._arg

    with local_worker_manager.LocalWorkerPool(Job, 2) as pool:
      with self.assertRaises(concurrent.futures.CancelledError):
        # this will fail because we didn't pass the arg to the ctor, so the
        # worker hosting process will crash.
        pool[0].method().result()

  def test_worker_crash_while_waiting(self):
    """Killing the worker process cancels a call that is still in flight."""

    class Job(Worker):

      def method(self):
        time.sleep(3600)

    with local_worker_manager.LocalWorkerPool(Job, 2) as pool:
      stub = pool[0]
      in_flight = stub.method()
      self.assertFalse(in_flight.done())
      try:
        stub._process.kill()  # pylint: disable=protected-access
      finally:
        with self.assertRaises(concurrent.futures.CancelledError):
          _ = in_flight.result()
100+
101+
102+
if __name__ == '__main__':
  # `multiprocessing` here is tf_agents' system_multiprocessing (the later of
  # the file's two imports bound to that name), which provides
  # handle_test_main.
  multiprocessing.handle_test_main(absltest.main)

compiler_opt/distributed/worker.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Common abstraction for a worker contract."""
16+
17+
18+
class Worker:
  """Base class for worker objects; defines the priority-method hook."""

  @classmethod
  def is_priority_method(cls, method_name: str) -> bool:
    """Tells whether `method_name` names a priority method.

    The base implementation treats no method as priority; subclasses override
    this to opt specific methods in.

    Args:
      method_name: the name of the method being asked about.

    Returns:
      Always False.
    """
    del method_name  # Unused by the base implementation.
    return False

0 commit comments

Comments
 (0)