From eb5685d89444037ad04a6bf0ad1252b3121a201d Mon Sep 17 00:00:00 2001 From: Jacob Hegna Date: Tue, 31 Jan 2023 21:03:03 +0000 Subject: [PATCH 1/6] Refactor early-exit in data_collector into the WorkerPool. Additionally, make scheduling work a method on the work pool. The local worker pool will continue to use the buffered scheduler, but alternative (distributed) worker pools will use alternative strategies. --- .../distributed/local/local_worker_manager.py | 17 +++- compiler_opt/distributed/worker.py | 85 +++++++++++++------ compiler_opt/rl/corpus.py | 11 ++- compiler_opt/rl/data_collector.py | 63 +++++++++++++- compiler_opt/rl/local_data_collector.py | 74 ++++------------ compiler_opt/rl/train_locally.py | 13 ++- 6 files changed, 168 insertions(+), 95 deletions(-) diff --git a/compiler_opt/distributed/local/local_worker_manager.py b/compiler_opt/distributed/local/local_worker_manager.py index 50f3321a..d6d35d34 100644 --- a/compiler_opt/distributed/local/local_worker_manager.py +++ b/compiler_opt/distributed/local/local_worker_manager.py @@ -37,6 +37,7 @@ from absl import logging # pylint: disable=unused-import from compiler_opt.distributed import worker +from compiler_opt.distributed import buffered_scheduler from contextlib import AbstractContextManager from multiprocessing import connection @@ -217,12 +218,14 @@ def shutdown(self): def set_nice(self, val: int): """Sets the nice-ness of the process, this modifies how the OS + schedules it. Only works on Unix, since val is presumed to be an int. """ psutil.Process(self._process.pid).nice(val) def set_affinity(self, val: List[int]): """Sets the CPU affinity of the process, this modifies which cores the OS + schedules it on. """ psutil.Process(self._process.pid).cpu_affinity(val) @@ -238,6 +241,18 @@ def __dir__(self): return _Stub() +class LocalWorkerPool(worker.FixedWorkerPool): + + def __init__(self, workers: List[Any], worker_concurrency: int): + super().__init__(workers=workers, worker_concurrency=worker_concurrency) + + def schedule(self, work: List[Any]) -> List[worker.WorkerFuture]: + return buffered_scheduler.schedule( + work, + workers=self.get_currently_active(), + buffer=self.get_worker_concurrency()) + + class LocalWorkerPoolManager(AbstractContextManager): """A pool of workers hosted on the local machines, each in its own process.""" @@ -251,7 +266,7 @@ def __init__(self, worker_class: 'type[worker.Worker]', count: Optional[int], ] def __enter__(self) -> worker.FixedWorkerPool: - return worker.FixedWorkerPool(workers=self._stubs, worker_concurrency=10) + return LocalWorkerPool(workers=self._stubs, worker_concurrency=10) def __exit__(self, *args): # first, trigger killing the worker process and exiting of the msg pump, diff --git a/compiler_opt/distributed/worker.py b/compiler_opt/distributed/worker.py index dea01267..8aa7b90b 100644 --- a/compiler_opt/distributed/worker.py +++ b/compiler_opt/distributed/worker.py @@ -32,34 +32,6 @@ def is_priority_method(cls, method_name: str) -> bool: T = TypeVar('T') -class WorkerPool(metaclass=abc.ABCMeta): - """Abstraction of a pool of workers that may be refreshed.""" - - # Issue #155 would strongly-type the return type. 
- @abc.abstractmethod - def get_currently_active(self) -> List[Any]: - raise NotImplementedError() - - @abc.abstractmethod - def get_worker_concurrency(self) -> int: - raise NotImplementedError() - - -class FixedWorkerPool(WorkerPool): - """A WorkerPool built from a fixed list of workers.""" - - # Issue #155 would strongly-type `workers` - def __init__(self, workers: List[Any], worker_concurrency: int = 2): - self._workers = workers - self._worker_concurrency = worker_concurrency - - def get_currently_active(self): - return self._workers - - def get_worker_concurrency(self): - return self._worker_concurrency - - # Dask's Futures are limited. This captures that. class WorkerFuture(Protocol[T]): @@ -91,6 +63,63 @@ def get_exception(worker_future: WorkerFuture) -> Optional[Exception]: return e +def lift_futures_through_list(future_list: WorkerFuture[List], + expected_size: int) -> List[WorkerFuture]: + """Convert Future[List] to List[Future].""" + flattened = [concurrent.futures.Future() for _ in range(size)] + + def _handler(fut): + if e := get_exception(future_list): + for f in flattened: + f.set_exception(e) + return + + for i, res in enumerate(fut.result()): + assert i < size + if isinstance(res, Exception): + flattened[i].set_exception(res) + else: + flattened[i].set_result(res) + for j in range(i + 1, size): + flattened[j].set_exception( + ValueError(f'No value returned for index {j} in future_list')) + + fut.add_done_callback(_handler) + return flattened + + +class WorkerPool(metaclass=abc.ABCMeta): + """Abstraction of a pool of workers that may be refreshed.""" + + # Issue #155 would strongly-type the return type. + @abc.abstractmethod + def get_currently_active(self) -> List[Any]: + raise NotImplementedError() + + @abc.abstractmethod + def get_worker_concurrency(self) -> int: + raise NotImplementedError() + + @abc.abstractmethod + def schedule(self, work: List[Any]) -> List[WorkerFuture]: + raise NotImplementedError() + + +class FixedWorkerPool(WorkerPool): + """A WorkerPool built from a fixed list of workers.""" + + # Issue #155 would strongly-type `workers` + def __init__(self, workers: List[Any], worker_concurrency: int = 2): + self._workers = workers + self._worker_concurrency = worker_concurrency + + def get_currently_active(self): + return self._workers + + def get_worker_concurrency(self): + return self._worker_concurrency + + def get_full_worker_args(worker_class: 'type[Worker]', current_kwargs): """Get the union of given kwargs and gin config. diff --git a/compiler_opt/rl/corpus.py b/compiler_opt/rl/corpus.py index 37b7f0da..d9f7a184 100644 --- a/compiler_opt/rl/corpus.py +++ b/compiler_opt/rl/corpus.py @@ -112,6 +112,7 @@ def build_command_line(self, local_dir: str) -> FullyQualifiedCmdLine: @dataclass(frozen=True) class ModuleSpec: """Metadata of a compilation unit. + This contains the necessary information to enable corpus operations like sampling or filtering, as well as to enable the corpus create a LoadedModuleSpec from a CorpusElement. @@ -132,6 +133,7 @@ def __call__(self, n: int = 20) -> List[ModuleSpec]: """ Args: + module_specs: list of module_specs to sample from k: number of modules to sample n: number of buckets to use @@ -141,6 +143,7 @@ def __call__(self, class SamplerBucketRoundRobin(Sampler): """Calls return a list of module_specs sampled randomly from n buckets, in + round-robin order. 
The buckets are sequential sections of module_specs of roughly equal lengths.""" @@ -153,6 +156,7 @@ def __call__(self, n: int = 20) -> List[ModuleSpec]: """ Args: + module_specs: list of module_specs to sample from k: number of modules to sample n: number of buckets to use @@ -233,6 +237,7 @@ def __init__(self, sampler: Sampler = SamplerBucketRoundRobin()): """ Prepares the corpus by pre-loading all the CorpusElements and preparing for + sampling. Command line origin (.cmd file or override) is decided, and final command line transformation rules are set (i.e. thinlto flags handled, also output) and validated. @@ -240,10 +245,10 @@ def __init__(self, Args: data_path: corpus directory. additional_flags: list of flags to append to the command line - delete_flags: list of flags to remove (both `-flag=` are supported). - replace_flags: list of flags to be replaced. The key in the dictionary - is the flag. The value is a string that will be `format`-ed with a + replace_flags: list of flags to be replaced. The key in the dictionary is + the flag. The value is a string that will be `format`-ed with a `context` object - see `ReplaceContext`. We verify that flags in replace_flags are present, and do not appear in the additional_flags nor delete_flags. diff --git a/compiler_opt/rl/data_collector.py b/compiler_opt/rl/data_collector.py index 81586ae5..9f3a7af1 100644 --- a/compiler_opt/rl/data_collector.py +++ b/compiler_opt/rl/data_collector.py @@ -15,13 +15,17 @@ """Data collection module.""" import abc +import concurrent.futures import time -from typing import Dict, Iterator, Tuple, Sequence +from typing import Any, Dict, Iterator, List, Tuple, Sequence +from absl import logging import numpy as np from compiler_opt.rl import policy_saver from tf_agents.trajectories import trajectory +from compiler_opt.distributed import worker + # Deadline for data collection. DEADLINE_IN_SECONDS = 30 @@ -138,3 +142,60 @@ def wait(self, get_num_finished_work): def waited_time(self): return self._waited_time + + +class CancelledForEarlyExitException(Exception): + ... 
+ + +def _create_cancelled_future(): + f = concurrent.futures.Future() + f.set_exception(CancelledForEarlyExitException()) + return f + + +class EarlyExitWorkerPool(worker.WorkerPool): + + def __init__(self, + worker_pool: worker.WorkerPool, + exit_checker_ctor=EarlyExitChecker): + self._worker_pool = worker_pool + self._reset_workers_pool = concurrent.futures.ThreadPoolExecutor() + self._reset_workers_future: Optional[concurrent.futures.Future] = None + self._exit_checker_ctor = exit_checker_ctor + + def get_currently_active(self) -> List[Any]: + return self._worker_pool.get_currently_active() + + def get_worker_concurrency(self) -> int: + return self._worker_pool.get_worker_concurrency() + + def schedule(self, work: List[Any]) -> List[worker.WorkerFuture]: + t1 = time.time() + if self._reset_workers_future: + concurrent.futures.wait([self._reset_workers_future]) + self._reset_workers_future = None + logging.info('Waiting for pending work took %f', time.time() - t1) + + result_futures = self._worker_pool.schedule(work) + early_exit = self._exit_checker_ctor(num_modules=len(work)) + early_exit.wait(lambda: sum(res.done() for res in result_futures)) + + def _wrapup(): + workers = self._worker_pool.get_currently_active() + cancel_futures = [wkr.cancel_all_work() for wkr in workers] + worker.wait_for(cancel_futures) + # now that the workers killed pending compilations, make sure the workers + # drained their working queues first - they should all complete quickly + # since the cancellation manager is killing immediately any process starts + worker.wait_for(result_futures) + worker.wait_for([wkr.enable() for wkr in workers]) + + def _process_future(f): + if f.done(): + return f + return _create_cancelled_future() + + results = [_process_future(f) for f in result_futures] + self._reset_future = self._reset_workers_pool.submit(_wrapup) + return results diff --git a/compiler_opt/rl/local_data_collector.py b/compiler_opt/rl/local_data_collector.py index 9a18dce5..4a4944a4 100644 --- a/compiler_opt/rl/local_data_collector.py +++ b/compiler_opt/rl/local_data_collector.py @@ -42,29 +42,15 @@ def __init__( parser: Callable[[List[str]], Iterator[trajectory.Trajectory]], reward_stat_map: Dict[str, Optional[Dict[str, compilation_runner.RewardStat]]], - best_trajectory_repo: Optional[best_trajectory.BestTrajectoryRepo], - exit_checker_ctor=data_collector.EarlyExitChecker): - # TODO(mtrofin): type exit_checker_ctor when we get typing.Protocol support + best_trajectory_repo: Optional[best_trajectory.BestTrajectoryRepo]): super().__init__() self._corpus = cps self._num_modules = num_modules self._parser = parser self._worker_pool = worker_pool - self._workers: List[ - compilation_runner - .CompilationRunnerStub] = self._worker_pool.get_currently_active() self._reward_stat_map = reward_stat_map self._best_trajectory_repo = best_trajectory_repo - self._exit_checker_ctor = exit_checker_ctor - # _reset_workers is a future that resolves when post-data collection cleanup - # work completes, i.e. cancelling all work and re-enabling the workers. - # We remove this activity from the critical path by running it concurrently - # with the training phase - i.e. whatever happens between successive data - # collection calls. Subsequent runs will wait for these to finish. 
- self._reset_workers: Optional[concurrent.futures.Future] = None - self._current_futures: List[worker.WorkerFuture] = [] - self._pool = concurrent.futures.ThreadPoolExecutor() self._prefetch_pool = concurrent.futures.ThreadPoolExecutor() self._next_sample: List[ concurrent.futures.Future] = self._prefetch_next_sample() @@ -80,33 +66,18 @@ def _prefetch_next_sample(self): return ret def close_pool(self): - self._join_pending_jobs() # if the pool lost some workers, that's fine - we don't need to tell them # anything anymore. To the new ones, the call is redundant (fine). - for p in self._workers: + for p in self._worker_pool.get_currently_active(): p.cancel_all_work() - self._workers = None self._worker_pool = None - def _join_pending_jobs(self): - t1 = time.time() - if self._reset_workers: - concurrent.futures.wait([self._reset_workers]) - - self._reset_workers = None - # this should have taken negligible time, normally, since all the work - # has been cancelled and the workers had time to process the cancellation - # while training was unfolding. - logging.info('Waiting for pending work from last iteration took %f', - time.time() - t1) - def _schedule_jobs( self, policy: policy_saver.Policy, model_id: int, sampled_modules: List[corpus.LoadedModuleSpec] ) -> List[worker.WorkerFuture[compilation_runner.CompilationResult]]: # by now, all the pending work, which was signaled to cancel, must've # finished - self._join_pending_jobs() jobs = [(loaded_module_spec, policy, self._reward_stat_map[loaded_module_spec.name]) for loaded_module_spec in sampled_modules] @@ -119,9 +90,7 @@ def work(w: compilation_runner.CompilationRunnerStub): return work work = [work_factory(job) for job in jobs] - self._workers = self._worker_pool.get_currently_active() - return buffered_scheduler.schedule( - work, self._workers, self._worker_pool.get_worker_concurrency()) + return self._worker_pool.schedule(work) def collect_data( self, policy: policy_saver.Policy, model_id: int @@ -145,40 +114,29 @@ def collect_data( logging.info('resolving prefetched sample took: %d seconds', time.time() - time1) self._next_sample = self._prefetch_next_sample() - self._current_futures = self._schedule_jobs(policy, model_id, - sampled_modules) - def wait_for_termination(): - early_exit = self._exit_checker_ctor(num_modules=self._num_modules) + time_before_schedule = time.time() + current_futures = self._schedule_jobs(policy, model_id, sampled_modules) - def get_num_finished_work(): - finished_work = sum(res.done() for res in self._current_futures) - return finished_work + current_work = list(zip(sampled_modules, current_futures)) - return early_exit.wait(get_num_finished_work) + def is_cancelled(fut): + if not fut.done(): + return False + if e := worker.get_exception(fut): + return isinstance(e, data_collector.CancelledForEarlyExitException) + return False - wait_seconds = wait_for_termination() - current_work = list(zip(sampled_modules, self._current_futures)) finished_work = [(spec, res) for spec, res in current_work if res.done()] successful_work = [(spec, res.result()) for spec, res in finished_work if not worker.get_exception(res)] - failures = len(finished_work) - len(successful_work) + cancelled_work = [res for res in current_futures if is_cancelled(res)] + failures = len(finished_work) - len(successful_work) - len(cancelled_work) logging.info(('%d of %d modules finished in %d seconds (%d failures).'), - len(finished_work), self._num_modules, wait_seconds, failures) - - # signal whatever work is left to finish, and re-enable 
workers. - def wrapup(): - cancel_futures = [wkr.cancel_all_work() for wkr in self._workers] - worker.wait_for(cancel_futures) - # now that the workers killed pending compilations, make sure the workers - # drained their working queues first - they should all complete quickly - # since the cancellation manager is killing immediately any process starts - worker.wait_for(self._current_futures) - worker.wait_for([wkr.enable() for wkr in self._workers]) - - self._reset_workers = self._pool.submit(wrapup) + len(finished_work) - len(cancelled_work), self._num_modules, + time.time() - time_before_schedule, failures) sequence_examples = list( itertools.chain.from_iterable( diff --git a/compiler_opt/rl/train_locally.py b/compiler_opt/rl/train_locally.py index 137031fa..0fda57ad 100644 --- a/compiler_opt/rl/train_locally.py +++ b/compiler_opt/rl/train_locally.py @@ -37,6 +37,7 @@ from compiler_opt.rl import corpus from compiler_opt.rl import data_reader from compiler_opt.rl import gin_external_configurables # pylint: disable=unused-import +from compiler_opt.rl import data_collector from compiler_opt.rl import local_data_collector from compiler_opt.rl import policy_saver from compiler_opt.rl import random_net_distillation @@ -61,6 +62,7 @@ @gin.configurable def train_eval(worker_manager_class=LocalWorkerPoolManager, + use_early_exit_worker_pool=True, agent_name=constant.AgentName.PPO, warmstart_policy_dir=None, num_policy_iterations=0, @@ -149,7 +151,10 @@ def sequence_example_iterator_fn(seq_ex: List[str]): worker_class=problem_config.get_runner_type(), count=FLAGS.num_workers, moving_average_decay_rate=moving_average_decay_rate) as worker_pool: - data_collector = local_data_collector.LocalDataCollector( + if use_early_exit_worker_pool: + logging.info('Constructing early exit worker pool wrapper') + worker_pool = data_collector.EarlyExitWorkerPool(worker_pool) + train_data_collector = local_data_collector.LocalDataCollector( cps=cps, num_modules=num_modules, worker_pool=worker_pool, @@ -174,17 +179,17 @@ def sequence_example_iterator_fn(seq_ex: List[str]): str(llvm_trainer.global_step_numpy())) saver.save(policy_path) - dataset_iter, monitor_dict = data_collector.collect_data( + dataset_iter, monitor_dict = train_data_collector.collect_data( policy=policy_saver.Policy.from_filesystem( os.path.join(policy_path, deploy_policy_name)), model_id=llvm_trainer.global_step_numpy()) llvm_trainer.train(dataset_iter, monitor_dict, num_iterations) - data_collector.on_dataset_consumed(dataset_iter) + train_data_collector.on_dataset_consumed(dataset_iter) # Save final policy. saver.save(root_dir) - # Wait for all the workers to finish. + # Wait for all the workers to finish data_collector.close_pool() From 8dfa3eb052e57cf49270fa2d5fddd577fa3aa61f Mon Sep 17 00:00:00 2001 From: Jacob Hegna Date: Tue, 31 Jan 2023 21:08:49 +0000 Subject: [PATCH 2/6] Fix unintended format changes. --- compiler_opt/distributed/local/local_worker_manager.py | 2 -- compiler_opt/rl/corpus.py | 10 +++------- compiler_opt/rl/train_locally.py | 2 +- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/compiler_opt/distributed/local/local_worker_manager.py b/compiler_opt/distributed/local/local_worker_manager.py index d6d35d34..246cb389 100644 --- a/compiler_opt/distributed/local/local_worker_manager.py +++ b/compiler_opt/distributed/local/local_worker_manager.py @@ -218,14 +218,12 @@ def shutdown(self): def set_nice(self, val: int): """Sets the nice-ness of the process, this modifies how the OS - schedules it. 
Only works on Unix, since val is presumed to be an int. """ psutil.Process(self._process.pid).nice(val) def set_affinity(self, val: List[int]): """Sets the CPU affinity of the process, this modifies which cores the OS - schedules it on. """ psutil.Process(self._process.pid).cpu_affinity(val) diff --git a/compiler_opt/rl/corpus.py b/compiler_opt/rl/corpus.py index d9f7a184..8ea6f038 100644 --- a/compiler_opt/rl/corpus.py +++ b/compiler_opt/rl/corpus.py @@ -112,7 +112,6 @@ def build_command_line(self, local_dir: str) -> FullyQualifiedCmdLine: @dataclass(frozen=True) class ModuleSpec: """Metadata of a compilation unit. - This contains the necessary information to enable corpus operations like sampling or filtering, as well as to enable the corpus create a LoadedModuleSpec from a CorpusElement. @@ -133,7 +132,6 @@ def __call__(self, n: int = 20) -> List[ModuleSpec]: """ Args: - module_specs: list of module_specs to sample from k: number of modules to sample n: number of buckets to use @@ -143,7 +141,6 @@ def __call__(self, class SamplerBucketRoundRobin(Sampler): """Calls return a list of module_specs sampled randomly from n buckets, in - round-robin order. The buckets are sequential sections of module_specs of roughly equal lengths.""" @@ -156,7 +153,6 @@ def __call__(self, n: int = 20) -> List[ModuleSpec]: """ Args: - module_specs: list of module_specs to sample from k: number of modules to sample n: number of buckets to use @@ -245,10 +241,10 @@ def __init__(self, Args: data_path: corpus directory. additional_flags: list of flags to append to the command line - delete_flags: list of flags to remove (both `-flag=` are supported). - replace_flags: list of flags to be replaced. The key in the dictionary is - the flag. The value is a string that will be `format`-ed with a + replace_flags: list of flags to be replaced. The key in the dictionary + is the flag. The value is a string that will be `format`-ed with a `context` object - see `ReplaceContext`. We verify that flags in replace_flags are present, and do not appear in the additional_flags nor delete_flags. diff --git a/compiler_opt/rl/train_locally.py b/compiler_opt/rl/train_locally.py index 0fda57ad..72de7c04 100644 --- a/compiler_opt/rl/train_locally.py +++ b/compiler_opt/rl/train_locally.py @@ -189,7 +189,7 @@ def sequence_example_iterator_fn(seq_ex: List[str]): # Save final policy. saver.save(root_dir) - # Wait for all the workers to finish + # Wait for all the workers to finish. data_collector.close_pool() From 10b7288bb0ddce9e69bba03d93e5aac3e3256cbd Mon Sep 17 00:00:00 2001 From: Jacob Hegna Date: Tue, 31 Jan 2023 21:23:42 +0000 Subject: [PATCH 3/6] Fix tests and add documentation for EarlyExitWorkerPool. --- compiler_opt/rl/data_collector.py | 25 ++++++++++++++++++++ compiler_opt/rl/local_data_collector.py | 12 +++++++--- compiler_opt/rl/local_data_collector_test.py | 9 ++++--- 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/compiler_opt/rl/data_collector.py b/compiler_opt/rl/data_collector.py index 9f3a7af1..ad9b859d 100644 --- a/compiler_opt/rl/data_collector.py +++ b/compiler_opt/rl/data_collector.py @@ -155,10 +155,22 @@ def _create_cancelled_future(): class EarlyExitWorkerPool(worker.WorkerPool): + """Worker pool wrapper which performs early-exit checking. + + Note that this worker pool wraps another worker pool, and this wrapper only + manages cancelling work from the underlying pool. Also, due to the nature of + "early exit," the futures that this pool's schedule() method returns are all + already .done(). 
+ """ def __init__(self, worker_pool: worker.WorkerPool, exit_checker_ctor=EarlyExitChecker): + """ + Args: + worker_pool: the underlying worker pool to schedule work on. + exit_checker_ctor: the exit checker constructor to use. + """ self._worker_pool = worker_pool self._reset_workers_pool = concurrent.futures.ThreadPoolExecutor() self._reset_workers_future: Optional[concurrent.futures.Future] = None @@ -171,6 +183,19 @@ def get_worker_concurrency(self) -> int: return self._worker_pool.get_worker_concurrency() def schedule(self, work: List[Any]) -> List[worker.WorkerFuture]: + """Schedule the provided work on the underlying worker pool. + + After the work is scheduled, this method blocks until the early exit + checker deems it ok to exit early. Work that was cancelled will have a + future with a CancelledForEarlyExitException error. + + Args: + work: the work to schedule. + + Returns: + a list of futures which all are already .done(). + """ + t1 = time.time() if self._reset_workers_future: concurrent.futures.wait([self._reset_workers_future]) diff --git a/compiler_opt/rl/local_data_collector.py b/compiler_opt/rl/local_data_collector.py index 4a4944a4..8deb1588 100644 --- a/compiler_opt/rl/local_data_collector.py +++ b/compiler_opt/rl/local_data_collector.py @@ -51,6 +51,7 @@ def __init__( self._worker_pool = worker_pool self._reward_stat_map = reward_stat_map self._best_trajectory_repo = best_trajectory_repo + self._current_futures: List[worker.WorkerFuture] = [] self._prefetch_pool = concurrent.futures.ThreadPoolExecutor() self._next_sample: List[ concurrent.futures.Future] = self._prefetch_next_sample() @@ -116,9 +117,14 @@ def collect_data( self._next_sample = self._prefetch_next_sample() time_before_schedule = time.time() - current_futures = self._schedule_jobs(policy, model_id, sampled_modules) + self._current_futures = self._schedule_jobs(policy, model_id, sampled_modules) - current_work = list(zip(sampled_modules, current_futures)) + # Wait for all futures to complete. We don't do any early-exit checking as + # that functionality has been moved to the + # data_collector.EarlyExitWorkerPool abstraction. 
+ worker.wait_for(self._current_futures) + + current_work = list(zip(sampled_modules, self._current_futures)) def is_cancelled(fut): if not fut.done(): @@ -131,7 +137,7 @@ def is_cancelled(fut): successful_work = [(spec, res.result()) for spec, res in finished_work if not worker.get_exception(res)] - cancelled_work = [res for res in current_futures if is_cancelled(res)] + cancelled_work = [res for res in self._current_futures if is_cancelled(res)] failures = len(finished_work) - len(successful_work) - len(cancelled_work) logging.info(('%d of %d modules finished in %d seconds (%d failures).'), diff --git a/compiler_opt/rl/local_data_collector_test.py b/compiler_opt/rl/local_data_collector_test.py index f361b16a..f0ed51a9 100644 --- a/compiler_opt/rl/local_data_collector_test.py +++ b/compiler_opt/rl/local_data_collector_test.py @@ -233,16 +233,15 @@ def wait(self, _): for i in range(200) ]), num_modules=4, - worker_pool=lwp, + worker_pool=data_collector.EarlyExitWorkerPool(lwp, QuickExiter), parser=parser, reward_stat_map=collections.defaultdict(lambda: None), - best_trajectory_repo=None, - exit_checker_ctor=QuickExiter) + best_trajectory_repo=None) collector.collect_data(policy=_mock_policy, model_id=0) - collector._join_pending_jobs() killed = 0 for w in collector._current_futures: - self.assertRaises(compilation_runner.ProcessKilledError, w.result) + self.assertRaises(data_collector.CancelledForEarlyExitException, + w.result) killed += 1 self.assertEqual(killed, 4) collector.close_pool() From fa2e5ec9a2cb71ca724e61ea409e07dff28c8a93 Mon Sep 17 00:00:00 2001 From: Jacob Hegna Date: Tue, 31 Jan 2023 21:27:54 +0000 Subject: [PATCH 4/6] Fix unintended formatting change. --- compiler_opt/rl/corpus.py | 1 - 1 file changed, 1 deletion(-) diff --git a/compiler_opt/rl/corpus.py b/compiler_opt/rl/corpus.py index 8ea6f038..37b7f0da 100644 --- a/compiler_opt/rl/corpus.py +++ b/compiler_opt/rl/corpus.py @@ -233,7 +233,6 @@ def __init__(self, sampler: Sampler = SamplerBucketRoundRobin()): """ Prepares the corpus by pre-loading all the CorpusElements and preparing for - sampling. Command line origin (.cmd file or override) is decided, and final command line transformation rules are set (i.e. thinlto flags handled, also output) and validated. From 0c7229e26990d1f1456d2f9042e38be783572d19 Mon Sep 17 00:00:00 2001 From: Jacob Hegna Date: Tue, 31 Jan 2023 21:46:37 +0000 Subject: [PATCH 5/6] Add worker_test.py. 
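
The helper under test, lift_futures_through_list, turns a single future that
resolves to a list into one future per expected element. A minimal sketch of
the contract the new tests check (the values below are illustrative only):

    import concurrent.futures
    from compiler_opt.distributed import worker

    # A future that will eventually resolve to a two-element list.
    future_list = concurrent.futures.Future()
    per_item = worker.lift_futures_through_list(future_list, expected_size=2)

    # Completing the list future completes the per-item futures; an exception
    # stored in the list propagates to the matching per-item future.
    future_list.set_result([1, ValueError('error')])
    assert per_item[0].result() == 1
    assert isinstance(worker.get_exception(per_item[1]), ValueError)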
--- compiler_opt/distributed/worker.py | 11 ++--- compiler_opt/distributed/worker_test.py | 60 +++++++++++++++++++++++++ compiler_opt/rl/local_data_collector.py | 6 ++- 3 files changed, 70 insertions(+), 7 deletions(-) create mode 100644 compiler_opt/distributed/worker_test.py diff --git a/compiler_opt/distributed/worker.py b/compiler_opt/distributed/worker.py index 8aa7b90b..8fee04d4 100644 --- a/compiler_opt/distributed/worker.py +++ b/compiler_opt/distributed/worker.py @@ -15,6 +15,7 @@ """Common abstraction for a worker contract.""" import abc +import concurrent.futures import sys from typing import Any, List, Iterable, Optional, Protocol, TypeVar @@ -66,25 +67,25 @@ def get_exception(worker_future: WorkerFuture) -> Optional[Exception]: def lift_futures_through_list(future_list: WorkerFuture[List], expected_size: int) -> List[WorkerFuture]: """Convert Future[List] to List[Future].""" - flattened = [concurrent.futures.Future() for _ in range(size)] + flattened = [concurrent.futures.Future() for _ in range(expected_size)] def _handler(fut): - if e := get_exception(future_list): + if e := get_exception(fut): for f in flattened: f.set_exception(e) return for i, res in enumerate(fut.result()): - assert i < size + assert i < expected_size if isinstance(res, Exception): flattened[i].set_exception(res) else: flattened[i].set_result(res) - for j in range(i + 1, size): + for j in range(i + 1, expected_size): flattened[j].set_exception( ValueError(f'No value returned for index {j} in future_list')) - fut.add_done_callback(_handler) + future_list.add_done_callback(_handler) return flattened diff --git a/compiler_opt/distributed/worker_test.py b/compiler_opt/distributed/worker_test.py new file mode 100644 index 00000000..216c1611 --- /dev/null +++ b/compiler_opt/distributed/worker_test.py @@ -0,0 +1,60 @@ +# coding=utf-8 +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Test for worker.""" + +from absl.testing import absltest +import concurrent.futures + +from compiler_opt.distributed import worker + + +class LiftFuturesThroughListTest(absltest.TestCase): + + def test_normal_path(self): + expected_list = [1, True, [2.0, False]] + future_list = concurrent.futures.Future() + list_future = worker.lift_futures_through_list(future_list, + len(expected_list)) + future_list.set_result(expected_list) + worker.wait_for(list_future) + + self.assertEqual([f.result() for f in list_future], expected_list) + + def test_with_exceptions_in_list(self): + expected_list = [1, ValueError('error')] + future_list = concurrent.futures.Future() + list_future = worker.lift_futures_through_list(future_list, + len(expected_list)) + future_list.set_result(expected_list) + worker.wait_for(list_future) + + self.assertEqual(list_future[0].result(), expected_list[0]) + self.assertTrue( + isinstance(worker.get_exception(list_future[1]), ValueError)) + + def test_list_is_exception(self): + expected_size = 42 + future_list = concurrent.futures.Future() + list_future = worker.lift_futures_through_list(future_list, expected_size) + future_list.set_exception(ValueError('error')) + + worker.wait_for(list_future) + self.assertEqual(len(list_future), expected_size) + for f in list_future: + self.assertTrue(isinstance(worker.get_exception(f), ValueError)) + + +if __name__ == '__main__': + absltest.main() diff --git a/compiler_opt/rl/local_data_collector.py b/compiler_opt/rl/local_data_collector.py index 8deb1588..dbbca504 100644 --- a/compiler_opt/rl/local_data_collector.py +++ b/compiler_opt/rl/local_data_collector.py @@ -42,7 +42,8 @@ def __init__( parser: Callable[[List[str]], Iterator[trajectory.Trajectory]], reward_stat_map: Dict[str, Optional[Dict[str, compilation_runner.RewardStat]]], - best_trajectory_repo: Optional[best_trajectory.BestTrajectoryRepo]): + best_trajectory_repo: Optional[best_trajectory.BestTrajectoryRepo], + ): super().__init__() self._corpus = cps @@ -117,7 +118,8 @@ def collect_data( self._next_sample = self._prefetch_next_sample() time_before_schedule = time.time() - self._current_futures = self._schedule_jobs(policy, model_id, sampled_modules) + self._current_futures = self._schedule_jobs(policy, model_id, + sampled_modules) # Wait for all futures to complete. We don't do any early-exit checking as # that functionality has been moved to the From 93f6756516e33615b5352f6bad1672fa483b4bce Mon Sep 17 00:00:00 2001 From: Jacob Hegna Date: Tue, 31 Jan 2023 21:52:14 +0000 Subject: [PATCH 6/6] Fixed pytype error. --- compiler_opt/distributed/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler_opt/distributed/worker.py b/compiler_opt/distributed/worker.py index 8fee04d4..5e28aeac 100644 --- a/compiler_opt/distributed/worker.py +++ b/compiler_opt/distributed/worker.py @@ -64,7 +64,7 @@ def get_exception(worker_future: WorkerFuture) -> Optional[Exception]: return e -def lift_futures_through_list(future_list: WorkerFuture[List], +def lift_futures_through_list(future_list: WorkerFuture, expected_size: int) -> List[WorkerFuture]: """Convert Future[List] to List[Future].""" flattened = [concurrent.futures.Future() for _ in range(expected_size)]
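
For review context: after this series, data collection schedules work through
the WorkerPool abstraction, and early-exit handling lives in
data_collector.EarlyExitWorkerPool. A rough consumer-side sketch - the pool
construction, the `work` callables, and the handle_* helpers are placeholders,
not code from this series:

    from compiler_opt.distributed import worker
    from compiler_opt.rl import data_collector

    # Wrap an existing pool; schedule() blocks only until the early-exit
    # checker is satisfied, then cancels stragglers in the background.
    pool = data_collector.EarlyExitWorkerPool(worker_pool)
    futures = pool.schedule(work)

    # Every returned future is already done. Work cancelled by the early exit
    # carries CancelledForEarlyExitException instead of a result.
    for fut in futures:
      e = worker.get_exception(fut)
      if e is None:
        handle_result(fut.result())  # placeholder
      elif isinstance(e, data_collector.CancelledForEarlyExitException):
        pass  # cancelled by early exit - not a failure
      else:
        handle_failure(e)  # placeholder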