diff --git a/compiler_opt/distributed/local/local_worker_manager.py b/compiler_opt/distributed/local/local_worker_manager.py
index 50f3321a..246cb389 100644
--- a/compiler_opt/distributed/local/local_worker_manager.py
+++ b/compiler_opt/distributed/local/local_worker_manager.py
@@ -37,6 +37,7 @@ from absl import logging
 
 # pylint: disable=unused-import
 from compiler_opt.distributed import worker
+from compiler_opt.distributed import buffered_scheduler
 
 from contextlib import AbstractContextManager
 from multiprocessing import connection
@@ -238,6 +239,18 @@ def __dir__(self):
     return _Stub()
 
 
+class LocalWorkerPool(worker.FixedWorkerPool):
+
+  def __init__(self, workers: List[Any], worker_concurrency: int):
+    super().__init__(workers=workers, worker_concurrency=worker_concurrency)
+
+  def schedule(self, work: List[Any]) -> List[worker.WorkerFuture]:
+    return buffered_scheduler.schedule(
+        work,
+        workers=self.get_currently_active(),
+        buffer=self.get_worker_concurrency())
+
+
 class LocalWorkerPoolManager(AbstractContextManager):
   """A pool of workers hosted on the local machines, each in its own process."""
 
@@ -251,7 +264,7 @@ def __init__(self, worker_class: 'type[worker.Worker]', count: Optional[int],
     ]
 
   def __enter__(self) -> worker.FixedWorkerPool:
-    return worker.FixedWorkerPool(workers=self._stubs, worker_concurrency=10)
+    return LocalWorkerPool(workers=self._stubs, worker_concurrency=10)
 
   def __exit__(self, *args):
     # first, trigger killing the worker process and exiting of the msg pump,
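Illustrative usage sketch, not part of the patch: driving the new `LocalWorkerPool.schedule()`, which now encapsulates the `buffered_scheduler` call that callers previously made themselves. `MyWorker` and `do_work` are made-up names; real workers are e.g. the compilation runners.

```python
# Sketch only: MyWorker and do_work are hypothetical, not in the codebase.
from compiler_opt.distributed import worker
from compiler_opt.distributed.local import local_worker_manager


class MyWorker(worker.Worker):
  """Toy worker; a stub call to do_work returns a future."""

  def do_work(self, arg):
    return arg * 2


def make_job(arg):
  # Work items are callables taking a worker stub; the scheduler uses the
  # returned future to keep each worker's queue at the configured concurrency.
  def job(w):
    return w.do_work(arg)

  return job


with local_worker_manager.LocalWorkerPoolManager(
    worker_class=MyWorker, count=4) as pool:
  futures = pool.schedule([make_job(i) for i in range(100)])
  worker.wait_for(futures)
```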
diff --git a/compiler_opt/distributed/worker.py b/compiler_opt/distributed/worker.py
index dea01267..5e28aeac 100644
--- a/compiler_opt/distributed/worker.py
+++ b/compiler_opt/distributed/worker.py
@@ -15,6 +15,7 @@
 """Common abstraction for a worker contract."""
 
 import abc
+import concurrent.futures
 import sys
 
 from typing import Any, List, Iterable, Optional, Protocol, TypeVar
@@ -32,34 +33,6 @@ def is_priority_method(cls, method_name: str) -> bool:
 
 T = TypeVar('T')
 
 
-class WorkerPool(metaclass=abc.ABCMeta):
-  """Abstraction of a pool of workers that may be refreshed."""
-
-  # Issue #155 would strongly-type the return type.
-  @abc.abstractmethod
-  def get_currently_active(self) -> List[Any]:
-    raise NotImplementedError()
-
-  @abc.abstractmethod
-  def get_worker_concurrency(self) -> int:
-    raise NotImplementedError()
-
-
-class FixedWorkerPool(WorkerPool):
-  """A WorkerPool built from a fixed list of workers."""
-
-  # Issue #155 would strongly-type `workers`
-  def __init__(self, workers: List[Any], worker_concurrency: int = 2):
-    self._workers = workers
-    self._worker_concurrency = worker_concurrency
-
-  def get_currently_active(self):
-    return self._workers
-
-  def get_worker_concurrency(self):
-    return self._worker_concurrency
-
-
 # Dask's Futures are limited. This captures that.
 class WorkerFuture(Protocol[T]):
@@ -91,6 +64,63 @@ def get_exception(worker_future: WorkerFuture) -> Optional[Exception]:
     return e
 
 
+def lift_futures_through_list(future_list: WorkerFuture,
+                              expected_size: int) -> List[WorkerFuture]:
+  """Convert Future[List] to List[Future]."""
+  flattened = [concurrent.futures.Future() for _ in range(expected_size)]
+
+  def _handler(fut):
+    if e := get_exception(fut):
+      for f in flattened:
+        f.set_exception(e)
+      return
+
+    results = fut.result()
+    for i, res in enumerate(results):
+      if isinstance(res, Exception):
+        flattened[i].set_exception(res)
+      else:
+        flattened[i].set_result(res)
+    for j in range(len(results), expected_size):
+      flattened[j].set_exception(
+          ValueError(f'No value returned for index {j} in future_list'))
+
+  future_list.add_done_callback(_handler)
+  return flattened
+
+
+class WorkerPool(metaclass=abc.ABCMeta):
+  """Abstraction of a pool of workers that may be refreshed."""
+
+  # Issue #155 would strongly-type the return type.
+  @abc.abstractmethod
+  def get_currently_active(self) -> List[Any]:
+    raise NotImplementedError()
+
+  @abc.abstractmethod
+  def get_worker_concurrency(self) -> int:
+    raise NotImplementedError()
+
+  @abc.abstractmethod
+  def schedule(self, work: List[Any]) -> List[WorkerFuture]:
+    raise NotImplementedError()
+
+
+class FixedWorkerPool(WorkerPool):
+  """A WorkerPool built from a fixed list of workers."""
+
+  # Issue #155 would strongly-type `workers`
+  def __init__(self, workers: List[Any], worker_concurrency: int = 2):
+    self._workers = workers
+    self._worker_concurrency = worker_concurrency
+
+  def get_currently_active(self):
+    return self._workers
+
+  def get_worker_concurrency(self):
+    return self._worker_concurrency
+
+
 def get_full_worker_args(worker_class: 'type[Worker]', current_kwargs):
   """Get the union of given kwargs and gin config.
 
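Because `WorkerPool.schedule()` is now abstract, any pool implementation other than the local one must supply its own dispatch. A minimal sketch of that obligation, using a naive round-robin strategy chosen purely for illustration (the local pool delegates to `buffered_scheduler` instead):

```python
from typing import Any, List

from compiler_opt.distributed import worker


class RoundRobinPool(worker.FixedWorkerPool):
  """Illustrative only: dispatches work round-robin, ignoring concurrency."""

  def schedule(self, work: List[Any]) -> List[worker.WorkerFuture]:
    workers = self.get_currently_active()
    # Each work item is a callable that takes a worker stub and returns a
    # future; collecting those futures satisfies the schedule() contract.
    return [job(workers[i % len(workers)]) for i, job in enumerate(work)]
```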
+"""Test for worker.""" + +from absl.testing import absltest +import concurrent.futures + +from compiler_opt.distributed import worker + + +class LiftFuturesThroughListTest(absltest.TestCase): + + def test_normal_path(self): + expected_list = [1, True, [2.0, False]] + future_list = concurrent.futures.Future() + list_future = worker.lift_futures_through_list(future_list, + len(expected_list)) + future_list.set_result(expected_list) + worker.wait_for(list_future) + + self.assertEqual([f.result() for f in list_future], expected_list) + + def test_with_exceptions_in_list(self): + expected_list = [1, ValueError('error')] + future_list = concurrent.futures.Future() + list_future = worker.lift_futures_through_list(future_list, + len(expected_list)) + future_list.set_result(expected_list) + worker.wait_for(list_future) + + self.assertEqual(list_future[0].result(), expected_list[0]) + self.assertTrue( + isinstance(worker.get_exception(list_future[1]), ValueError)) + + def test_list_is_exception(self): + expected_size = 42 + future_list = concurrent.futures.Future() + list_future = worker.lift_futures_through_list(future_list, expected_size) + future_list.set_exception(ValueError('error')) + + worker.wait_for(list_future) + self.assertEqual(len(list_future), expected_size) + for f in list_future: + self.assertTrue(isinstance(worker.get_exception(f), ValueError)) + + +if __name__ == '__main__': + absltest.main() diff --git a/compiler_opt/rl/data_collector.py b/compiler_opt/rl/data_collector.py index 81586ae5..ad9b859d 100644 --- a/compiler_opt/rl/data_collector.py +++ b/compiler_opt/rl/data_collector.py @@ -15,13 +15,17 @@ """Data collection module.""" import abc +import concurrent.futures import time -from typing import Dict, Iterator, Tuple, Sequence +from typing import Any, Dict, Iterator, List, Tuple, Sequence +from absl import logging import numpy as np from compiler_opt.rl import policy_saver from tf_agents.trajectories import trajectory +from compiler_opt.distributed import worker + # Deadline for data collection. DEADLINE_IN_SECONDS = 30 @@ -138,3 +142,85 @@ def wait(self, get_num_finished_work): def waited_time(self): return self._waited_time + + +class CancelledForEarlyExitException(Exception): + ... + + +def _create_cancelled_future(): + f = concurrent.futures.Future() + f.set_exception(CancelledForEarlyExitException()) + return f + + +class EarlyExitWorkerPool(worker.WorkerPool): + """Worker pool wrapper which performs early-exit checking. + + Note that this worker pool wraps another worker pool, and this wrapper only + manages cancelling work from the underlying pool. Also, due to the nature of + "early exit," the futures that this pool's schedule() method returns are all + already .done(). + """ + + def __init__(self, + worker_pool: worker.WorkerPool, + exit_checker_ctor=EarlyExitChecker): + """ + Args: + worker_pool: the underlying worker pool to schedule work on. + exit_checker_ctor: the exit checker constructor to use. + """ + self._worker_pool = worker_pool + self._reset_workers_pool = concurrent.futures.ThreadPoolExecutor() + self._reset_workers_future: Optional[concurrent.futures.Future] = None + self._exit_checker_ctor = exit_checker_ctor + + def get_currently_active(self) -> List[Any]: + return self._worker_pool.get_currently_active() + + def get_worker_concurrency(self) -> int: + return self._worker_pool.get_worker_concurrency() + + def schedule(self, work: List[Any]) -> List[worker.WorkerFuture]: + """Schedule the provided work on the underlying worker pool. 
diff --git a/compiler_opt/rl/data_collector.py b/compiler_opt/rl/data_collector.py
index 81586ae5..ad9b859d 100644
--- a/compiler_opt/rl/data_collector.py
+++ b/compiler_opt/rl/data_collector.py
@@ -15,13 +15,17 @@
 """Data collection module."""
 
 import abc
+import concurrent.futures
 import time
-from typing import Dict, Iterator, Tuple, Sequence
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Sequence
 
+from absl import logging
 import numpy as np
 
 from compiler_opt.rl import policy_saver
 from tf_agents.trajectories import trajectory
 
+from compiler_opt.distributed import worker
+
 # Deadline for data collection.
 DEADLINE_IN_SECONDS = 30
@@ -138,3 +142,85 @@ def wait(self, get_num_finished_work):
 
   def waited_time(self):
     return self._waited_time
+
+
+class CancelledForEarlyExitException(Exception):
+  ...
+
+
+def _create_cancelled_future():
+  f = concurrent.futures.Future()
+  f.set_exception(CancelledForEarlyExitException())
+  return f
+
+
+class EarlyExitWorkerPool(worker.WorkerPool):
+  """Worker pool wrapper which performs early-exit checking.
+
+  This pool wraps another worker pool: scheduling is delegated to the wrapped
+  pool, and the wrapper only adds early-exit checking and cancellation of any
+  leftover work. Due to the nature of "early exit", the futures returned by
+  this pool's schedule() method are all already .done().
+  """
+
+  def __init__(self,
+               worker_pool: worker.WorkerPool,
+               exit_checker_ctor=EarlyExitChecker):
+    """
+    Args:
+      worker_pool: the underlying worker pool to schedule work on.
+      exit_checker_ctor: the exit checker constructor to use.
+    """
+    self._worker_pool = worker_pool
+    self._reset_workers_pool = concurrent.futures.ThreadPoolExecutor()
+    self._reset_workers_future: Optional[concurrent.futures.Future] = None
+    self._exit_checker_ctor = exit_checker_ctor
+
+  def get_currently_active(self) -> List[Any]:
+    return self._worker_pool.get_currently_active()
+
+  def get_worker_concurrency(self) -> int:
+    return self._worker_pool.get_worker_concurrency()
+
+  def schedule(self, work: List[Any]) -> List[worker.WorkerFuture]:
+    """Schedule the provided work on the underlying worker pool.
+
+    After the work is scheduled, this method blocks until the early exit
+    checker deems it OK to exit early. Work that was cancelled is returned as
+    a future carrying a CancelledForEarlyExitException.
+
+    Args:
+      work: the work to schedule.
+
+    Returns:
+      A list of futures, all of which are already .done().
+    """
+
+    t1 = time.time()
+    if self._reset_workers_future:
+      concurrent.futures.wait([self._reset_workers_future])
+      self._reset_workers_future = None
+    logging.info('Waiting for pending work took %f', time.time() - t1)
+
+    result_futures = self._worker_pool.schedule(work)
+    early_exit = self._exit_checker_ctor(num_modules=len(work))
+    early_exit.wait(lambda: sum(res.done() for res in result_futures))
+
+    def _wrapup():
+      workers = self._worker_pool.get_currently_active()
+      cancel_futures = [wkr.cancel_all_work() for wkr in workers]
+      worker.wait_for(cancel_futures)
+      # Now that the workers have killed pending compilations, make sure they
+      # have drained their work queues - these should all complete quickly,
+      # since the cancellation manager kills any process as soon as it starts.
+      worker.wait_for(result_futures)
+      worker.wait_for([wkr.enable() for wkr in workers])
+
+    def _process_future(f):
+      if f.done():
+        return f
+      return _create_cancelled_future()
+
+    results = [_process_future(f) for f in result_futures]
+    self._reset_workers_future = self._reset_workers_pool.submit(_wrapup)
+    return results
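Illustrative usage sketch, not part of the patch: `EarlyExitWorkerPool.schedule()` blocks until the exit checker is satisfied and returns only futures that are already `.done()`, so callers just classify the outcomes. `base_pool` and `jobs` are placeholders.

```python
from typing import Any, List

from compiler_opt.distributed import worker
from compiler_opt.rl import data_collector


def schedule_and_classify(base_pool: worker.WorkerPool, jobs: List[Any]):
  """Hypothetical helper: schedule with early exit, then sort the outcomes."""
  pool = data_collector.EarlyExitWorkerPool(base_pool)
  futures = pool.schedule(jobs)  # every returned future is already .done()
  successes = [f.result() for f in futures if not worker.get_exception(f)]
  cancelled = [
      f for f in futures if isinstance(
          worker.get_exception(f),
          data_collector.CancelledForEarlyExitException)
  ]
  return successes, cancelled
```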
diff --git a/compiler_opt/rl/local_data_collector.py b/compiler_opt/rl/local_data_collector.py
index 9a18dce5..dbbca504 100644
--- a/compiler_opt/rl/local_data_collector.py
+++ b/compiler_opt/rl/local_data_collector.py
@@ -43,28 +43,16 @@ def __init__(
       reward_stat_map: Dict[str, Optional[Dict[str,
                                                compilation_runner.RewardStat]]],
      best_trajectory_repo: Optional[best_trajectory.BestTrajectoryRepo],
-      exit_checker_ctor=data_collector.EarlyExitChecker):
-    # TODO(mtrofin): type exit_checker_ctor when we get typing.Protocol support
+  ):
     super().__init__()
     self._corpus = cps
     self._num_modules = num_modules
     self._parser = parser
     self._worker_pool = worker_pool
-    self._workers: List[
-        compilation_runner
-        .CompilationRunnerStub] = self._worker_pool.get_currently_active()
     self._reward_stat_map = reward_stat_map
     self._best_trajectory_repo = best_trajectory_repo
-    self._exit_checker_ctor = exit_checker_ctor
-    # _reset_workers is a future that resolves when post-data collection cleanup
-    # work completes, i.e. cancelling all work and re-enabling the workers.
-    # We remove this activity from the critical path by running it concurrently
-    # with the training phase - i.e. whatever happens between successive data
-    # collection calls. Subsequent runs will wait for these to finish.
-    self._reset_workers: Optional[concurrent.futures.Future] = None
     self._current_futures: List[worker.WorkerFuture] = []
-    self._pool = concurrent.futures.ThreadPoolExecutor()
     self._prefetch_pool = concurrent.futures.ThreadPoolExecutor()
     self._next_sample: List[
         concurrent.futures.Future] = self._prefetch_next_sample()
 
@@ -80,33 +68,18 @@ def _prefetch_next_sample(self):
     return ret
 
   def close_pool(self):
-    self._join_pending_jobs()
     # if the pool lost some workers, that's fine - we don't need to tell them
     # anything anymore. To the new ones, the call is redundant (fine).
-    for p in self._workers:
+    for p in self._worker_pool.get_currently_active():
       p.cancel_all_work()
-    self._workers = None
     self._worker_pool = None
 
-  def _join_pending_jobs(self):
-    t1 = time.time()
-    if self._reset_workers:
-      concurrent.futures.wait([self._reset_workers])
-
-    self._reset_workers = None
-    # this should have taken negligible time, normally, since all the work
-    # has been cancelled and the workers had time to process the cancellation
-    # while training was unfolding.
-    logging.info('Waiting for pending work from last iteration took %f',
-                 time.time() - t1)
-
   def _schedule_jobs(
       self, policy: policy_saver.Policy, model_id: int,
       sampled_modules: List[corpus.LoadedModuleSpec]
   ) -> List[worker.WorkerFuture[compilation_runner.CompilationResult]]:
     # by now, all the pending work, which was signaled to cancel, must've
     # finished
-    self._join_pending_jobs()
     jobs = [(loaded_module_spec, policy,
              self._reward_stat_map[loaded_module_spec.name])
             for loaded_module_spec in sampled_modules]
@@ -119,9 +92,7 @@ def work(w: compilation_runner.CompilationRunnerStub):
       return work
 
     work = [work_factory(job) for job in jobs]
-    self._workers = self._worker_pool.get_currently_active()
-    return buffered_scheduler.schedule(
-        work, self._workers, self._worker_pool.get_worker_concurrency())
+    return self._worker_pool.schedule(work)
 
   def collect_data(
       self, policy: policy_saver.Policy, model_id: int
@@ -145,40 +116,35 @@ def collect_data(
     logging.info('resolving prefetched sample took: %d seconds',
                  time.time() - time1)
     self._next_sample = self._prefetch_next_sample()
+
+    time_before_schedule = time.time()
     self._current_futures = self._schedule_jobs(policy, model_id,
                                                 sampled_modules)
 
-    def wait_for_termination():
-      early_exit = self._exit_checker_ctor(num_modules=self._num_modules)
+    # Wait for all futures to complete. We don't do any early-exit checking as
+    # that functionality has been moved to the
+    # data_collector.EarlyExitWorkerPool abstraction.
+    worker.wait_for(self._current_futures)
 
-      def get_num_finished_work():
-        finished_work = sum(res.done() for res in self._current_futures)
-        return finished_work
+    current_work = list(zip(sampled_modules, self._current_futures))
 
-      return early_exit.wait(get_num_finished_work)
+    def is_cancelled(fut):
+      if not fut.done():
+        return False
+      if e := worker.get_exception(fut):
+        return isinstance(e, data_collector.CancelledForEarlyExitException)
+      return False
 
-    wait_seconds = wait_for_termination()
-    current_work = list(zip(sampled_modules, self._current_futures))
     finished_work = [(spec, res) for spec, res in current_work if res.done()]
     successful_work = [(spec, res.result())
                        for spec, res in finished_work
                        if not worker.get_exception(res)]
-    failures = len(finished_work) - len(successful_work)
+    cancelled_work = [res for res in self._current_futures if is_cancelled(res)]
+    failures = len(finished_work) - len(successful_work) - len(cancelled_work)
 
     logging.info(('%d of %d modules finished in %d seconds (%d failures).'),
-                 len(finished_work), self._num_modules, wait_seconds, failures)
-
-    # signal whatever work is left to finish, and re-enable workers.
-    def wrapup():
-      cancel_futures = [wkr.cancel_all_work() for wkr in self._workers]
-      worker.wait_for(cancel_futures)
-      # now that the workers killed pending compilations, make sure the workers
-      # drained their working queues first - they should all complete quickly
-      # since the cancellation manager is killing immediately any process starts
-      worker.wait_for(self._current_futures)
-      worker.wait_for([wkr.enable() for wkr in self._workers])
-
-    self._reset_workers = self._pool.submit(wrapup)
+                 len(finished_work) - len(cancelled_work), self._num_modules,
+                 time.time() - time_before_schedule, failures)
 
     sequence_examples = list(
         itertools.chain.from_iterable(
diff --git a/compiler_opt/rl/local_data_collector_test.py b/compiler_opt/rl/local_data_collector_test.py
index f361b16a..f0ed51a9 100644
--- a/compiler_opt/rl/local_data_collector_test.py
+++ b/compiler_opt/rl/local_data_collector_test.py
@@ -233,16 +233,15 @@ def wait(self, _):
             for i in range(200)
         ]),
         num_modules=4,
-        worker_pool=lwp,
+        worker_pool=data_collector.EarlyExitWorkerPool(lwp, QuickExiter),
         parser=parser,
         reward_stat_map=collections.defaultdict(lambda: None),
-        best_trajectory_repo=None,
-        exit_checker_ctor=QuickExiter)
+        best_trajectory_repo=None)
     collector.collect_data(policy=_mock_policy, model_id=0)
-    collector._join_pending_jobs()
     killed = 0
     for w in collector._current_futures:
-      self.assertRaises(compilation_runner.ProcessKilledError, w.result)
+      self.assertRaises(data_collector.CancelledForEarlyExitException,
+                        w.result)
       killed += 1
     self.assertEqual(killed, 4)
     collector.close_pool()
diff --git a/compiler_opt/rl/train_locally.py b/compiler_opt/rl/train_locally.py
index 137031fa..72de7c04 100644
--- a/compiler_opt/rl/train_locally.py
+++ b/compiler_opt/rl/train_locally.py
@@ -37,6 +37,7 @@
 from compiler_opt.rl import corpus
 from compiler_opt.rl import data_reader
 from compiler_opt.rl import gin_external_configurables  # pylint: disable=unused-import
+from compiler_opt.rl import data_collector
 from compiler_opt.rl import local_data_collector
 from compiler_opt.rl import policy_saver
 from compiler_opt.rl import random_net_distillation
@@ -61,6 +62,7 @@
 
 @gin.configurable
 def train_eval(worker_manager_class=LocalWorkerPoolManager,
+               use_early_exit_worker_pool=True,
                agent_name=constant.AgentName.PPO,
                warmstart_policy_dir=None,
                num_policy_iterations=0,
@@ -149,7 +151,10 @@ def sequence_example_iterator_fn(seq_ex: List[str]):
       worker_class=problem_config.get_runner_type(),
      count=FLAGS.num_workers,
       moving_average_decay_rate=moving_average_decay_rate) as worker_pool:
-    data_collector = local_data_collector.LocalDataCollector(
+    if use_early_exit_worker_pool:
+      logging.info('Constructing early exit worker pool wrapper')
+      worker_pool = data_collector.EarlyExitWorkerPool(worker_pool)
+    train_data_collector = local_data_collector.LocalDataCollector(
         cps=cps,
         num_modules=num_modules,
         worker_pool=worker_pool,
@@ -174,13 +179,13 @@ def sequence_example_iterator_fn(seq_ex: List[str]):
                 str(llvm_trainer.global_step_numpy()))
       saver.save(policy_path)
 
-      dataset_iter, monitor_dict = data_collector.collect_data(
+      dataset_iter, monitor_dict = train_data_collector.collect_data(
          policy=policy_saver.Policy.from_filesystem(
              os.path.join(policy_path, deploy_policy_name)),
          model_id=llvm_trainer.global_step_numpy())
       llvm_trainer.train(dataset_iter, monitor_dict, num_iterations)
 
-      data_collector.on_dataset_consumed(dataset_iter)
+      train_data_collector.on_dataset_consumed(dataset_iter)
 
     # Save final policy.
     saver.save(root_dir)
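The wrapper is enabled by default. A sketch of turning it off through gin's standard `bind_parameter` API; the equivalent `.gin` file line would be `train_eval.use_early_exit_worker_pool = False`.

```python
import gin

# Assumes gin-config's public bind_parameter API; disables the
# EarlyExitWorkerPool wrapper so the manager's pool is used directly.
gin.bind_parameter('train_eval.use_early_exit_worker_pool', False)
```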