Skip to content

Commit 02a0276

Browse files
authored
Simplify scheduling work on a worker pool (#219)
The user can just specify how to call a worker with some work, instead of having to go through the trouble of creating a work factory and so on.
1 parent 8b04c35 commit 02a0276

File tree

3 files changed

+84
-21
lines changed

3 files changed

+84
-21
lines changed

compiler_opt/distributed/buffered_scheduler.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@
1919
import concurrent.futures
2020
import threading
2121

22-
from typing import List, Callable, TypeVar
22+
from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar
2323

2424
from compiler_opt.distributed import worker
2525

2626
T = TypeVar('T')
27+
W = TypeVar('W')
2728

2829

2930
def schedule(work: List[Callable[[T], worker.WorkerFuture]],
@@ -80,3 +81,38 @@ def chain_work(wkr: T):
8081
chain_work(w)
8182

8283
return results
84+
85+
86+
def schedule_on_worker_pool(
87+
action: Callable[[W, T], Any],
88+
jobs: Iterable[T],
89+
worker_pool: worker.WorkerPool,
90+
buffer_size: Optional[int] = None
91+
) -> Tuple[List[W], List[worker.WorkerFuture]]:
92+
"""
93+
Schedule the given action on workers from the given worker pool.
94+
Args:
95+
action: a function that, given a worker and some args, calls that worker
96+
with those args.
97+
jobs: a list of arguments, each element constituting a unit of work.
98+
worker_pool: the worker pool on which to schedule the work.
99+
buffer_size: if provided, buffer these many work items, instead of the
100+
worker manager's default.
101+
102+
Returns:
103+
a tuple. The first value is the workers that are used to perform the work.
104+
The second is a list of futures, one for each work item.
105+
"""
106+
107+
def work_factory(args):
108+
109+
def work(w: worker.Worker):
110+
return action(w, args)
111+
112+
return work
113+
114+
work = [work_factory(job) for job in jobs]
115+
workers: List[W] = worker_pool.get_currently_active()
116+
return workers, schedule(work, workers,
117+
(worker_pool.get_worker_concurrency()
118+
if buffer_size is None else buffer_size))

compiler_opt/distributed/buffered_scheduler_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,42 @@
2121
from absl.testing import absltest
2222
from compiler_opt.distributed import worker
2323
from compiler_opt.distributed import buffered_scheduler
24+
from compiler_opt.distributed.local import local_worker_manager
2425

2526

2627
class BufferedSchedulerTest(absltest.TestCase):
2728

29+
def test_simple_scheduling(self):
30+
31+
class TheWorker(worker.Worker):
32+
33+
def square(self, the_value, extra_factor=1):
34+
return the_value * the_value * extra_factor
35+
36+
with local_worker_manager.LocalWorkerPoolManager(TheWorker, 2) as pool:
37+
workers, futures = buffered_scheduler.schedule_on_worker_pool(
38+
lambda w, v: w.square(v), range(10), pool)
39+
self.assertLen(workers, 2)
40+
self.assertLen(futures, 10)
41+
worker.wait_for(futures)
42+
self.assertListEqual([f.result() for f in futures],
43+
[x * x for x in range(10)])
44+
45+
_, futures = buffered_scheduler.schedule_on_worker_pool(
46+
lambda w, v: w.square(**v), [dict(the_value=v) for v in range(10)],
47+
pool)
48+
worker.wait_for(futures)
49+
self.assertListEqual([f.result() for f in futures],
50+
[x * x for x in range(10)])
51+
52+
# same idea, but mix some kwargs
53+
_, futures = buffered_scheduler.schedule_on_worker_pool(
54+
lambda w, v: w.square(v[0], **v[1]),
55+
[(v, dict(extra_factor=10)) for v in range(10)], pool)
56+
worker.wait_for(futures)
57+
self.assertListEqual([f.result() for f in futures],
58+
[x * x * 10 for x in range(10)])
59+
2860
def test_schedules(self):
2961
call_count = [0] * 4
3062
locks = [threading.Lock() for _ in range(4)]

compiler_opt/rl/local_data_collector.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -100,28 +100,24 @@ def _join_pending_jobs(self):
100100
logging.info('Waiting for pending work from last iteration took %f',
101101
time.time() - t1)
102102

103-
def _schedule_jobs(
104-
self, policy: policy_saver.Policy, model_id: int,
105-
sampled_modules: List[corpus.LoadedModuleSpec]
106-
) -> List[worker.WorkerFuture[compilation_runner.CompilationResult]]:
103+
def _schedule_jobs(self, policy: policy_saver.Policy, model_id: int,
104+
sampled_modules: List[corpus.LoadedModuleSpec]) -> None:
107105
# by now, all the pending work, which was signaled to cancel, must've
108106
# finished
109107
self._join_pending_jobs()
110-
jobs = [(loaded_module_spec, policy,
111-
self._reward_stat_map[loaded_module_spec.name])
112-
for loaded_module_spec in sampled_modules]
113-
114-
def work_factory(job):
115-
116-
def work(w: compilation_runner.CompilationRunnerStub):
117-
return w.collect_data(*job, model_id=model_id)
118-
119-
return work
108+
jobs = [
109+
dict(
110+
loaded_module_spec=loaded_module_spec,
111+
policy=policy,
112+
reward_stat=self._reward_stat_map[loaded_module_spec.name],
113+
model_id=model_id) for loaded_module_spec in sampled_modules
114+
]
120115

121-
work = [work_factory(job) for job in jobs]
122-
self._workers = self._worker_pool.get_currently_active()
123-
return buffered_scheduler.schedule(
124-
work, self._workers, self._worker_pool.get_worker_concurrency())
116+
(self._workers,
117+
self._current_futures) = buffered_scheduler.schedule_on_worker_pool(
118+
action=lambda w, kwargs: w.collect_data(**kwargs),
119+
jobs=jobs,
120+
worker_pool=self._worker_pool)
125121

126122
def collect_data(
127123
self, policy: policy_saver.Policy, model_id: int
@@ -145,8 +141,7 @@ def collect_data(
145141
logging.info('resolving prefetched sample took: %d seconds',
146142
time.time() - time1)
147143
self._next_sample = self._prefetch_next_sample()
148-
self._current_futures = self._schedule_jobs(policy, model_id,
149-
sampled_modules)
144+
self._schedule_jobs(policy, model_id, sampled_modules)
150145

151146
def wait_for_termination():
152147
early_exit = self._exit_checker_ctor(num_modules=self._num_modules)

0 commit comments

Comments
 (0)