Commit a3dbc19
Have the worker pool context produce a pool object (#154)
The motivation is:

- Allowing worker managers to specify the level of concurrency for a worker. This is set to '10' right now by the data collector, but it's really specific to the worker setup. For the local case, where we max out the hardware threads with workers, 10 is overly generous; for a distributed case, a number approaching the hardware threads would be more appropriate.
- Allowing distributed worker managers to update the set of available workers over time. Workers could get preempted, or new ones become available. Having a way to periodically check and update what's available - albeit there are never guarantees that a worker, once discovered, stays alive - helps avoid artificial starvation.

The rest is renames that fall out of this refactoring.
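For illustration, here is a minimal sketch of the usage shape this change produces: the worker manager's context now yields a pool object that callers query for the active workers and the per-worker concurrency, instead of a bare list of stubs. The JobNormal worker and the set_token calls are borrowed from the tests further below; the pool methods match the diff.

# Sketch only: the context manager now yields a FixedWorkerPool.
with local_worker_manager.LocalWorkerPoolManager(JobNormal, 2) as pool:
  workers = pool.get_currently_active()        # worker stubs currently usable
  concurrency = pool.get_worker_concurrency()  # per-worker in-flight budget; 10 for the local case
  futures = [w.set_token(i) for i, w in enumerate(workers)]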
1 parent 9668e36 commit a3dbc19

File tree (6 files changed, +59 / -22 lines)

- compiler_opt/distributed/local/local_worker_manager.py
- compiler_opt/distributed/local/local_worker_manager_test.py
- compiler_opt/distributed/worker.py
- compiler_opt/rl/local_data_collector.py
- compiler_opt/rl/local_data_collector_test.py
- compiler_opt/rl/train_locally.py


compiler_opt/distributed/local/local_worker_manager.py

Lines changed: 4 additions & 4 deletions
@@ -36,7 +36,7 @@

 from absl import logging
 # pylint: disable=unused-import
-from compiler_opt.distributed.worker import Worker
+from compiler_opt.distributed.worker import Worker, FixedWorkerPool

 from contextlib import AbstractContextManager
 from multiprocessing import connection
@@ -238,7 +238,7 @@ def __dir__(self):
   return _Stub()


-class LocalWorkerPool(AbstractContextManager):
+class LocalWorkerPoolManager(AbstractContextManager):
   """A pool of workers hosted on the local machines, each in its own process."""

   def __init__(self, worker_class: 'type[Worker]', count: Optional[int], *args,
@@ -249,8 +249,8 @@ def __init__(self, worker_class: 'type[Worker]', count: Optional[int], *args,
         _make_stub(worker_class, *args, **kwargs) for _ in range(count)
     ]

-  def __enter__(self):
-    return self._stubs
+  def __enter__(self) -> FixedWorkerPool:
+    return FixedWorkerPool(workers=self._stubs, worker_concurrency=10)

   def __exit__(self, *args):
     # first, trigger killing the worker process and exiting of the msg pump,
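The worker_concurrency=10 above simply preserves the value the data collector used to hard-code; the point of the refactoring is that each worker manager can now pick a number appropriate to its own setup. A hypothetical variation, not part of this commit, could derive the budget from the worker host rather than a fixed constant:

import os

# Hypothetical helper: a distributed-style manager might size the per-worker
# in-flight budget from the hardware thread count instead of a constant.
def _default_worker_concurrency() -> int:
  return max(1, os.cpu_count() or 1)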

compiler_opt/distributed/local/local_worker_manager_test.py

Lines changed: 7 additions & 7 deletions
@@ -62,9 +62,9 @@ class LocalWorkerManagerTest(absltest.TestCase):

   def test_pool(self):

-    with local_worker_manager.LocalWorkerPool(JobNormal, 2) as pool:
-      p1 = pool[0]
-      p2 = pool[1]
+    with local_worker_manager.LocalWorkerPoolManager(JobNormal, 2) as pool:
+      p1 = pool.get_currently_active()[0]
+      p2 = pool.get_currently_active()[1]
       set_futures = [p1.set_token(1), p2.set_token(2)]
       done, not_done = concurrent.futures.wait(set_futures)
       self.assertLen(done, 2)
@@ -81,16 +81,16 @@ def test_pool(self):

   def test_failure(self):

-    with local_worker_manager.LocalWorkerPool(JobFail, 2) as pool:
+    with local_worker_manager.LocalWorkerPoolManager(JobFail, 2) as pool:
       with self.assertRaises(concurrent.futures.CancelledError):
         # this will fail because we didn't pass the arg to the ctor, so the
         # worker hosting process will crash.
-        pool[0].method().result()
+        pool.get_currently_active()[0].method().result()

   def test_worker_crash_while_waiting(self):

-    with local_worker_manager.LocalWorkerPool(JobSlow, 2) as pool:
-      p = pool[0]
+    with local_worker_manager.LocalWorkerPoolManager(JobSlow, 2) as pool:
+      p = pool.get_currently_active()[0]
       f = p.method()
       self.assertFalse(f.done())
       try:

compiler_opt/distributed/worker.py

Lines changed: 30 additions & 1 deletion
@@ -14,7 +14,8 @@
 # limitations under the License.
 """Common abstraction for a worker contract."""

-from typing import Iterable, Optional, Protocol, TypeVar
+import abc
+from typing import Any, List, Iterable, Optional, Protocol, TypeVar


 class Worker(Protocol):
@@ -28,6 +29,34 @@ def is_priority_method(cls, method_name: str) -> bool:
 T = TypeVar('T')


+class WorkerPool(metaclass=abc.ABCMeta):
+  """Abstraction of a pool of workers that may be refreshed."""
+
+  # Issue #155 would strongly-type the return type.
+  @abc.abstractmethod
+  def get_currently_active(self) -> List[Any]:
+    raise NotImplementedError()
+
+  @abc.abstractmethod
+  def get_worker_concurrency(self) -> int:
+    raise NotImplementedError()
+
+
+class FixedWorkerPool(WorkerPool):
+  """A WorkerPool built from a fixed list of workers."""
+
+  # Issue #155 would strongly-type `workers`
+  def __init__(self, workers: List[Any], worker_concurrency: int = 2):
+    self._workers = workers
+    self._worker_concurrency = worker_concurrency
+
+  def get_currently_active(self):
+    return self._workers
+
+  def get_worker_concurrency(self):
+    return self._worker_concurrency
+
+
 # Dask's Futures are limited. This captures that.
 class WorkerFuture(Protocol[T]):

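FixedWorkerPool is the degenerate case of the new abstraction; get_currently_active() exists so a distributed manager can return a different set of workers as machines are preempted or provisioned. A sketch of such a pool, assuming a caller-supplied discover_workers callback (both the class and the callback are hypothetical, not part of this commit):

from typing import Any, Callable, List

from compiler_opt.distributed.worker import WorkerPool


class RefreshableWorkerPool(WorkerPool):
  """Hypothetical pool that re-discovers its workers on every query."""

  def __init__(self, discover_workers: Callable[[], List[Any]],
               worker_concurrency: int):
    self._discover_workers = discover_workers
    self._worker_concurrency = worker_concurrency

  def get_currently_active(self) -> List[Any]:
    # Workers may have been preempted, or new ones provisioned, since the
    # previous call; there is still no guarantee a returned worker stays alive.
    return self._discover_workers()

  def get_worker_concurrency(self) -> int:
    return self._worker_concurrency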
compiler_opt/rl/local_data_collector.py

Lines changed: 13 additions & 5 deletions
@@ -37,7 +37,7 @@ def __init__(
       self,
       cps: corpus.Corpus,
       num_modules: int,
-      worker_pool: List[compilation_runner.CompilationRunnerStub],
+      worker_pool: worker.WorkerPool,
       parser: Callable[[List[str]], Iterator[trajectory.Trajectory]],
       reward_stat_map: Dict[str, Optional[Dict[str,
                                                compilation_runner.RewardStat]]],
@@ -49,6 +49,9 @@ def __init__(
     self._num_modules = num_modules
     self._parser = parser
     self._worker_pool = worker_pool
+    self._workers: List[
+        compilation_runner
+        .CompilationRunnerStub] = self._worker_pool.get_currently_active()
     self._reward_stat_map = reward_stat_map
     self._exit_checker_ctor = exit_checker_ctor
     # _reset_workers is a future that resolves when post-data collection cleanup
@@ -75,8 +78,11 @@ def _prefetch_next_sample(self):

   def close_pool(self):
     self._join_pending_jobs()
-    for p in self._worker_pool:
+    # if the pool lost some workers, that's fine - we don't need to tell them
+    # anything anymore. To the new ones, the call is redudant (fine).
+    for p in self._workers:
       p.cancel_all_work()
+    self._workers = None
     self._worker_pool = None

   def _join_pending_jobs(self):
@@ -110,7 +116,9 @@ def work(w: compilation_runner.CompilationRunnerStub):
       return work

     work = [work_factory(job) for job in jobs]
-    return buffered_scheduler.schedule(work, self._worker_pool, buffer=10)
+    self._workers = self._worker_pool.get_currently_active()
+    return buffered_scheduler.schedule(
+        work, self._workers, self._worker_pool.get_worker_concurrency())

   def collect_data(
       self, policy: policy_saver.Policy
@@ -158,13 +166,13 @@ def get_num_finished_work():

     # signal whatever work is left to finish, and re-enable workers.
     def wrapup():
-      cancel_futures = [wkr.cancel_all_work() for wkr in self._worker_pool]
+      cancel_futures = [wkr.cancel_all_work() for wkr in self._workers]
       worker.wait_for(cancel_futures)
       # now that the workers killed pending compilations, make sure the workers
       # drained their working queues first - they should all complete quickly
       # since the cancellation manager is killing immediately any process starts
       worker.wait_for(self._current_futures)
-      worker.wait_for([wkr.enable() for wkr in self._worker_pool])
+      worker.wait_for([wkr.enable() for wkr in self._workers])

     self._reset_workers = self._pool.submit(wrapup)
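The collector now re-queries the pool right before scheduling and forwards the pool's concurrency to the scheduler. As a rough illustration of what that parameter means, here is a simplified, blocking stand-in; the real buffered_scheduler.schedule keeps futures in flight asynchronously, and this sketch only shows the per-worker budget:

import itertools
from typing import Any, Callable, List

def naive_schedule(work: List[Callable[[Any], Any]], workers: List[Any],
                   worker_concurrency: int) -> List[Any]:
  """Blocking stand-in: at most `worker_concurrency` jobs per worker per round."""
  results = []
  worker_cycle = itertools.cycle(workers)
  round_size = worker_concurrency * len(workers)
  for start in range(0, len(work), round_size):
    # Hand one round of jobs out round-robin, then wait for all of them.
    futures = [job(next(worker_cycle)) for job in work[start:start + round_size]]
    results.extend(f.result() for f in futures)
  return results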

compiler_opt/rl/local_data_collector_test.py

Lines changed: 3 additions & 3 deletions
@@ -23,7 +23,7 @@
 import tensorflow as tf
 from tf_agents.system import system_multiprocessing as multiprocessing

-from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPool
+from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPoolManager
 from compiler_opt.rl import compilation_runner
 from compiler_opt.rl import corpus
 from compiler_opt.rl import data_collector
@@ -142,7 +142,7 @@ def _test_iterator_fn(data_list):
       return _test_iterator_fn

     sampler = DeterministicSampler()
-    with LocalWorkerPool(worker_class=MyRunner, count=4) as lwp:
+    with LocalWorkerPoolManager(worker_class=MyRunner, count=4) as lwp:
       collector = local_data_collector.LocalDataCollector(
           cps=corpus.create_corpus_for_testing(
               location=self.create_tempdir(),
@@ -214,7 +214,7 @@ def __init__(self, num_modules):
       def wait(self, _):
         return False

-    with LocalWorkerPool(worker_class=Sleeper, count=4) as lwp:
+    with LocalWorkerPoolManager(worker_class=Sleeper, count=4) as lwp:
       collector = local_data_collector.LocalDataCollector(
           cps=corpus.create_corpus_for_testing(
               location=self.create_tempdir(),

compiler_opt/rl/train_locally.py

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@
 from tf_agents.system import system_multiprocessing as multiprocessing
 from typing import List

-from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPool
+from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPoolManager
 from compiler_opt.rl import agent_creators
 from compiler_opt.rl import compilation_runner
 from compiler_opt.rl import constant
@@ -59,7 +59,7 @@


 @gin.configurable
-def train_eval(worker_manager_class=LocalWorkerPool,
+def train_eval(worker_manager_class=LocalWorkerPoolManager,
                agent_name=constant.AgentName.PPO,
                warmstart_policy_dir=None,
                num_policy_iterations=0,
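Because train_eval is gin-configurable, the worker manager stays swappable without code changes. For example, assuming a hypothetical DistributedWorkerPoolManager registered with gin, a binding could select it in place of the local default:

import gin

# Hypothetical binding; DistributedWorkerPoolManager is not part of this commit.
gin.bind_parameter('train_eval.worker_manager_class',
                   DistributedWorkerPoolManager)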
