
Commit 3a2c995

Add load balancer (google#106)
Adds a buffered load balancer, which by default maintains at least 2 tasks assigned to each worker. Closes google#91
1 parent a08ad91 commit 3a2c995

6 files changed: +205 additions, -20 deletions
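For orientation before the diffs, here is a minimal sketch of how the new scheduler is meant to be driven, modeled on the unit test added in this commit. The helper names (make_worker, job) and the use of concurrent.futures.Future as the worker future are illustrative assumptions, not part of the change:

import concurrent.futures
import threading

from compiler_opt.distributed import worker
from compiler_opt.distributed.local import buffered_scheduler

call_count = [0] * 4
lock = threading.Lock()

def make_worker(i):
  # A "worker" here is just a callable that records that it ran.
  def wkr():
    with lock:
      call_count[i] += 1
  return wkr

workers = [make_worker(i) for i in range(4)]

def job(wkr):
  # A work item takes a worker and returns a future-like object.
  future = concurrent.futures.Future()

  def run():
    wkr()
    future.set_result(0)

  threading.Timer(interval=0.1, function=run).start()
  return future

# 20 jobs over 4 workers; with buffer=2 the scheduler keeps roughly two tasks
# assigned to each worker while work remains, chaining a new one as each completes.
futures = buffered_scheduler.schedule([job] * 20, workers, buffer=2)
worker.wait_for(futures)
assert sum(call_count) == 20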
compiler_opt/distributed/local/buffered_scheduler.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An optimal push-pull-based load balancer which attempts to maintain at least
`buffer` tasks assigned to each worker.
"""

import concurrent.futures
import threading

from typing import List, Callable, TypeVar

from compiler_opt.distributed import worker

T = TypeVar('T')


def schedule(work: List[Callable[[T], worker.WorkerFuture]],
             workers: List[T],
             buffer=2) -> List[worker.WorkerFuture]:
  """Assigns new work to a worker once that worker's previous work is completed.

  Args:
    work: list of functions, each called with a single worker as its argument.
    workers: list of workers, each passed as the singular argument to a work
      callable.
    buffer: number of tasks to maintain on each worker.

  Returns:
    A list of Futures, one per element of `work`.
  """
  # Create the futures to be returned first; they aren't bound to anything
  # now, but they will be later.
  results = [concurrent.futures.Future() for _ in range(len(work))]
  idx = -1
  idx_lock = threading.Lock()

  # Simple atomic increment and get.
  # Used to iterate over `work` like a thread-safe queue without making a copy.
  def fetch_idx():
    nonlocal idx
    with idx_lock:
      idx += 1
      return idx

  def make_result_handler(wkr: T, result_future: concurrent.futures.Future):

    def handler(worker_future: concurrent.futures.Future):
      if (e := worker_future.exception()) is not None:
        result_future.set_exception(e)
      else:
        result_future.set_result(worker_future.result())
      chain_work(wkr)

    return handler

  def chain_work(wkr: T):
    if (i := fetch_idx()) < len(work):
      # This potentially causes a deadlock if chain_work is called via a
      # future.set_result() context which holds a resource that is also
      # required to complete the work[i](wkr) call below. For an example, see:
      # https://gist.github.com/Northbadge/a57f2d4e0a71e8f3934bdb47e59e343e
      # A fix/workaround would be using threading below, but that introduces
      # the overhead of creating a new thread.
      work[i](wkr).add_done_callback(make_result_handler(wkr, results[i]))

  # Use min() in case buffer is huge for some reason.
  for _ in range(min(buffer, (len(work) // len(workers)) + 1)):
    for w in workers:
      chain_work(w)

  return results

compiler_opt/distributed/local/buffered_scheduler_test.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for buffered_scheduler."""

import concurrent.futures
import threading
import time

from absl.testing import absltest
from compiler_opt.distributed import worker
from compiler_opt.distributed.local import buffered_scheduler


class BufferedSchedulerTest(absltest.TestCase):

  def test_schedules(self):
    call_count = [0] * 4
    locks = [threading.Lock() for _ in range(4)]

    def wkr_factory(i):

      def wkr():
        with locks[i]:
          call_count[i] += 1

      return wkr

    wkrs = [wkr_factory(i) for i in range(4)]

    def job(wkr):
      future = concurrent.futures.Future()

      def task():
        wkr()
        future.set_result(0)

      threading.Timer(interval=0.10, function=task).start()
      return future

    work = [job] * 20

    worker.wait_for(buffered_scheduler.schedule(work, wkrs))
    self.assertEqual(sum(call_count), 20)

  def test_balances(self):
    call_count = [0] * 4
    locks = [threading.Lock() for _ in range(4)]

    def wkr_factory(i):

      def wkr():
        with locks[i]:
          call_count[i] += 1

      return wkr

    def slow_wkr():
      with locks[0]:
        call_count[0] += 1
        time.sleep(1)

    wkrs = [slow_wkr] + [wkr_factory(i) for i in range(1, 4)]

    def job(wkr):
      future = concurrent.futures.Future()

      def task():
        wkr()
        future.set_result(0)

      threading.Timer(interval=0.10, function=task).start()
      return future

    work = [job] * 20

    worker.wait_for(buffered_scheduler.schedule(work, wkrs, buffer=2))
    self.assertEqual(sum(call_count), 20)
    # Since buffer=2, 2 tasks get assigned to the slow wkr; the rest
    # should've been assigned elsewhere if load balancing works.
    self.assertEqual(call_count[0], 2)


if __name__ == '__main__':
  absltest.main()

compiler_opt/distributed/local/local_worker_manager.py

Lines changed: 5 additions & 0 deletions
@@ -161,6 +161,11 @@ def _msg_pump(self):
       with self._lock:
         future = self._map[task_result.msgid]
         del self._map[task_result.msgid]
+      # The following will trigger any callbacks defined on the future, as a
+      # direct function call. If those callbacks were set by the scheduler,
+      # it's important that self._lock isn't being held when they are being
+      # called, otherwise a deadlock could arise from __getattr__ trying to
+      # acquire the lock.
       if task_result.success:
         future.set_result(task_result.value)
       else:

compiler_opt/distributed/worker.py

Lines changed: 5 additions & 5 deletions
@@ -14,8 +14,7 @@
 # limitations under the License.
 """Common abstraction for a worker contract."""

-import abc
-from typing import Generic, Iterable, Optional, TypeVar
+from typing import Iterable, Optional, TypeVar, Protocol


 class Worker:
@@ -30,16 +29,17 @@ def is_priority_method(cls, method_name: str) -> bool:


 # Dask's Futures are limited. This captures that.
-class WorkerFuture(Generic[T], metaclass=abc.ABCMeta):
+class WorkerFuture(Protocol[T]):

-  @abc.abstractmethod
   def result(self) -> T:
     raise NotImplementedError()

-  @abc.abstractmethod
   def done(self) -> bool:
     raise NotImplementedError()

+  def add_done_callback(self, fn) -> None:
+    raise NotImplementedError
+

 def wait_for(futures: Iterable[WorkerFuture]):
   """Dask futures don't support more than result() and done()."""

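A side effect of moving WorkerFuture from an abc-based base class to typing.Protocol is that conformance is now structural: any object exposing result(), done() and add_done_callback() type-checks as a WorkerFuture without inheriting from it. A small sketch of that; the resolved() helper below is hypothetical and not part of this commit:

import concurrent.futures

from compiler_opt.distributed import worker

def resolved(value) -> worker.WorkerFuture:
  # concurrent.futures.Future never subclasses WorkerFuture, but it provides
  # result(), done() and add_done_callback(), so it satisfies the Protocol
  # structurally - which is what buffered_scheduler relies on.
  f = concurrent.futures.Future()
  f.set_result(value)
  return f

worker.wait_for([resolved(1), resolved(2)])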
compiler_opt/rl/local_data_collector.py

Lines changed: 16 additions & 14 deletions
@@ -23,6 +23,7 @@
 from tf_agents.trajectories import trajectory

 from compiler_opt.distributed import worker
+from compiler_opt.distributed.local import buffered_scheduler
 from compiler_opt.rl import compilation_runner
 from compiler_opt.rl import corpus
 from compiler_opt.rl import data_collector
@@ -55,7 +56,7 @@ def __init__(
     # with the training phase - i.e. whatever happens between successive data
     # collection calls. Subsequent runs will wait for these to finish.
     self._reset_workers: Optional[concurrent.futures.Future] = None
-    self._current_work: List[Tuple[corpus.ModuleSpec, worker.WorkerFuture]] = []
+    self._current_futures: List[worker.WorkerFuture] = []
     self._pool = concurrent.futures.ThreadPoolExecutor()

   def close_pool(self):
@@ -85,12 +86,15 @@ def _schedule_jobs(
     jobs = [(module_spec, policy_path, self._reward_stat_map[module_spec.name])
             for module_spec in sampled_modules]

-    # TODO: Issue #91. Naive load balancing.
-    ret = []
-    for i in range(len(jobs)):
-      ret.append(self._worker_pool[i % len(self._worker_pool)].collect_data(
-          *(jobs[i])))
-    return ret
+    def work_factory(job):
+
+      def work(w):
+        return w.collect_data(*job)
+
+      return work
+
+    work = [work_factory(job) for job in jobs]
+    return buffered_scheduler.schedule(work, self._worker_pool, buffer=10)

   def collect_data(
       self, policy_path: str
@@ -108,22 +112,20 @@ def collect_data(
       information is viewable in TensorBoard.
     """
     sampled_modules = self._corpus.sample(k=self._num_modules, sort=False)
-    results = self._schedule_jobs(policy_path, sampled_modules)
+    self._current_futures = self._schedule_jobs(policy_path, sampled_modules)

     def wait_for_termination():
       early_exit = self._exit_checker_ctor(num_modules=self._num_modules)

       def get_num_finished_work():
-        finished_work = sum(res.done() for res in results)
+        finished_work = sum(res.done() for res in self._current_futures)
         return finished_work

       return early_exit.wait(get_num_finished_work)

     wait_seconds = wait_for_termination()
-    self._current_work = list(zip(sampled_modules, results))
-    finished_work = [
-        (spec, res) for spec, res in self._current_work if res.done()
-    ]
+    current_work = list(zip(sampled_modules, self._current_futures))
+    finished_work = [(spec, res) for spec, res in current_work if res.done()]
     successful_work = [(spec, res.result())
                        for spec, res in finished_work
                        if not worker.get_exception(res)]
@@ -139,7 +141,7 @@ def wrapup():
       # now that the workers killed pending compilations, make sure the workers
       # drained their working queues first - they should all complete quickly
       # since the cancellation manager is killing immediately any process starts
-      worker.wait_for(results)
+      worker.wait_for(self._current_futures)
       worker.wait_for([wkr.enable() for wkr in self._worker_pool])

     self._reset_workers = self._pool.submit(wrapup)

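A brief note on the work_factory pattern in the hunk above: each job is bound through a factory call rather than a lambda defined directly in the loop, presumably to avoid Python's late-binding closures, where every loop-defined lambda would end up seeing only the last job. A standalone illustration of that pitfall (not part of the commit):

jobs = ['a', 'b', 'c']

late = [lambda: job for job in jobs]
print([f() for f in late])      # ['c', 'c', 'c'] - every closure sees the last job

def factory(job):
  return lambda: job            # binds the job passed in at this call

bound = [factory(job) for job in jobs]
print([f() for f in bound])     # ['a', 'b', 'c']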
compiler_opt/rl/local_data_collector_test.py

Lines changed: 1 addition & 1 deletion
@@ -188,7 +188,7 @@ def wait(self, _):
     collector.collect_data(policy_path='policy')
     collector._join_pending_jobs()
     killed = 0
-    for _, w in collector._current_work:
+    for w in collector._current_futures:
       self.assertRaises(compilation_runner.ProcessKilledError, w.result)
       killed += 1
     self.assertEqual(killed, 4)
