Skip to content

Commit 2f63d6b

Browse files
authored
Change the return type of buffered_scheduler APIs (#221)
They return canonical futures (i.e. `concurrent.futures.Future`). This is great because we can then use any `Future` APIs, e.g. `concurrent.futures.wait`.
1 parent 387ed6f commit 2f63d6b

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

compiler_opt/distributed/buffered_scheduler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
def schedule(work: List[Callable[[T], worker.WorkerFuture]],
3131
workers: List[T],
32-
buffer=2) -> List[worker.WorkerFuture]:
32+
buffer=2) -> List[concurrent.futures.Future]:
3333
"""
3434
Assigns work to workers once previous work of the worker are
3535
completed.
@@ -88,7 +88,7 @@ def schedule_on_worker_pool(
8888
jobs: Iterable[T],
8989
worker_pool: worker.WorkerPool,
9090
buffer_size: Optional[int] = None
91-
) -> Tuple[List[W], List[worker.WorkerFuture]]:
91+
) -> Tuple[List[W], List[concurrent.futures.Future]]:
9292
"""
9393
Schedule the given action on workers from the given worker pool.
9494
Args:

compiler_opt/distributed/buffered_scheduler_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,25 @@
2626

2727
class BufferedSchedulerTest(absltest.TestCase):
2828

29+
def test_canonical_futures(self):
30+
31+
class WaitingWorker(worker.Worker):
32+
33+
def wait_seconds(self, n: int):
34+
time.sleep(n)
35+
return n + 1
36+
37+
with local_worker_manager.LocalWorkerPoolManager(WaitingWorker, 2) as pool:
38+
_, futures = buffered_scheduler.schedule_on_worker_pool(
39+
lambda w, v: w.wait_seconds(v), range(4), pool)
40+
not_done = futures
41+
entered_count = 0
42+
while not_done:
43+
_, not_done = concurrent.futures.wait(not_done, timeout=0.5)
44+
entered_count += 1
45+
self.assertGreater(entered_count, 1)
46+
self.assertListEqual([r.result() for r in futures], list(range(1, 5)))
47+
2948
def test_simple_scheduling(self):
3049

3150
class TheWorker(worker.Worker):

compiler_opt/rl/local_data_collector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def wrapup():
170170
# now that the workers killed pending compilations, make sure the workers
171171
# drained their working queues first - they should all complete quickly
172172
# since the cancellation manager is killing immediately any process starts
173-
worker.wait_for(self._current_futures)
173+
concurrent.futures.wait(self._current_futures)
174174
worker.wait_for([wkr.enable() for wkr in self._workers])
175175

176176
self._reset_workers = self._pool.submit(wrapup)

0 commit comments

Comments
 (0)