| 1 | +# coding=utf-8 |
| 2 | +# Copyright 2020 Google LLC |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +"""Validation data collection module.""" |
| 16 | +import concurrent.futures |
| 17 | +import threading |
| 18 | +import time |
| 19 | +from typing import Dict, Optional, List, Tuple |
| 20 | + |
| 21 | +from absl import logging |
| 22 | + |
| 23 | +from compiler_opt.distributed import worker |
| 24 | +from compiler_opt.distributed.local import buffered_scheduler |
| 25 | +from compiler_opt.distributed.local import cpu_affinity |
| 26 | +from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPool |
| 27 | + |
| 28 | +from compiler_opt.rl import corpus |
| 29 | + |
| 30 | + |
| 31 | +class LocalValidationDataCollector(worker.ContextAwareWorker): |
| 32 | + """Local implementation of a validation data collector |
| 33 | + Args: |
| 34 | + module_specs: List of module specs to use |
| 35 | + worker_pool_args: Pool of workers to use |
| 36 | + """ |
| 37 | + |
| 38 | + def __init__(self, cps: corpus.Corpus, worker_pool_args, reward_stat_map, |
| 39 | + max_cpus): |
| 40 | + self._num_modules = len(cps) if cps is not None else 0 |
| 41 | + self._corpus: corpus.Corpus = cps |
| 42 | + self._default_rewards = {} |
| 43 | + |
| 44 | + self._running_policy = None |
| 45 | + self._default_futures: List[worker.WorkerFuture] = [] |
| 46 | + self._current_work: List[Tuple[corpus.ModuleSpec, worker.WorkerFuture]] = [] |
| 47 | + self._last_time = None |
| 48 | + self._elapsed_time = 0 |
| 49 | + |
| 50 | + self._context_local = True |
| 51 | + |
| 52 | +    # Return only after the fields above are set, since other methods expect them. |
| 53 | + if not cps: |
| 54 | + return |
| 55 | + |
| 56 | + affinities = cpu_affinity.set_and_get( |
| 57 | + is_main_process=False, max_cpus=max_cpus) |
| 58 | + |
| 59 | +    # Add some runner-specific flags. |
| 60 | + logging.info('Validation data collector using %d workers.', len(affinities)) |
| 61 | + worker_pool_args['count'] = len(affinities) |
| 62 | + worker_pool_args['moving_average_decay_rate'] = 1 |
| 63 | + worker_pool_args['compilation_timeout'] = 1200 |
| 64 | + |
| 65 | + # Borrow from the external reward_stat_map in case it got loaded from disk |
| 66 | + # and already has some values. On a fresh run this will be recalculated |
| 67 | + # from scratch in the main data collector and here. It would be ideal if |
| 68 | + # both shared the same dict, but that would be too complex to implement. |
| 69 | + for name, data in reward_stat_map.items(): |
| 70 | + if name not in self._default_rewards: |
| 71 | + self._default_rewards[name] = {} |
| 72 | + for identifier, reward_stat in data.items(): |
| 73 | + self._default_rewards[name][identifier] = reward_stat.default_reward |
| 74 | + |
| 75 | + self._pool = LocalWorkerPool(**worker_pool_args) |
| 76 | + self._worker_pool = self._pool.stubs |
| 77 | + |
| 78 | + for i, p in zip(affinities, self._worker_pool): |
| 79 | + p.set_nice(19) |
| 80 | + p.set_affinity([i]) |
| 81 | + |
| 82 | + # BEGIN: ContextAwareWorker methods |
| 83 | + @classmethod |
| 84 | + def is_priority_method(cls, _: str) -> bool: |
| 85 | + # Everything is a priority: this is essentially a synchronous RPC endpoint. |
| 86 | + return True |
| 87 | + |
| 88 | + def set_context(self, local: bool): |
| 89 | + self._context_local = local |
| 90 | + |
| 91 | + # END: ContextAwareWorker methods |
| 92 | + |
| 93 | + def _schedule_jobs(self, policy_path, module_specs): |
| 94 | + default_jobs = [] |
| 95 | + for module_spec in module_specs: |
| 96 | + if module_spec.name not in self._default_rewards: |
| 97 | + # The bool is reward_only, None is cancellation_manager |
| 98 | + default_jobs.append((module_spec, '', True, None)) |
| 99 | + |
| 100 | + default_rewards_lock = threading.Lock() |
| 101 | + |
| 102 | + def create_update_rewards(spec_name): |
| 103 | + |
| 104 | + def updater(f: concurrent.futures.Future): |
| 105 | +        if f.exception() is not None: |
| 106 | +          return |
| 107 | +        for identifier, (_, default_reward) in f.result().items(): |
| 108 | +          with default_rewards_lock: |
| 109 | +            self._default_rewards.setdefault( |
| 110 | +                spec_name, {})[identifier] = default_reward |
| 111 | + return updater |
| 112 | + |
| 113 | + # The bool is reward_only, None is cancellation_manager |
| 114 | + policy_jobs = [ |
| 115 | + (module_spec, policy_path, True, None) for module_spec in module_specs |
| 116 | + ] |
| 117 | + |
| 118 | + def work_factory(job): |
| 119 | + |
| 120 | + def work(w): |
| 121 | + return w.compile_fn(*job) |
| 122 | + |
| 123 | + return work |
| 124 | + |
| 125 | + work = [work_factory(job) for job in default_jobs] |
| 126 | + work += [work_factory(job) for job in policy_jobs] |
| 127 | + |
| 128 | + futures = buffered_scheduler.schedule(work, self._worker_pool, buffer=10) |
| 129 | + |
| 130 | + self._default_futures = futures[:len(default_jobs)] |
| 131 | + policy_futures = futures[len(default_jobs):] |
| 132 | + |
| 133 | + for job, future in zip(default_jobs, self._default_futures): |
| 134 | +      future.add_done_callback(create_update_rewards(job[0].name)) |
| 135 | + |
| 136 | + return policy_futures |
| 137 | + |
| 138 | + def collect_data_async( |
| 139 | + self, |
| 140 | + policy_path: str, |
| 141 | +      step: int = 0) -> Optional[Dict[int, Dict[str, float]]]: |
| 142 | + """Collect data for a given policy. |
| 143 | + |
| 144 | + Args: |
| 145 | + policy_path: the path to the policy directory to collect data with. |
| 146 | + step: the step number associated with the policy_path |
| 147 | + |
| 148 | + Returns: |
| 149 | + Either returns data in the form of a dictionary, or returns None if the |
| 150 | + data is not ready yet. |
| 151 | + """ |
| 152 | + if self._num_modules == 0: |
| 153 | + return None |
| 154 | + |
| 155 | +    # Resume immediately so that, if new jobs get scheduled, they run |
| 156 | +    # while the last batch's results are being processed. |
| 157 | + self.resume_children() |
| 158 | + finished_work = [ |
| 159 | + (spec, res) for spec, res in self._current_work if res.done() |
| 160 | + ] |
| 161 | + |
| 162 | + # Check if there are default rewards being collected. |
| 163 | + if len(self._default_futures) > 0: |
| 164 | + finished_default_work = sum(res.done() for res in self._default_futures) |
| 165 | + if finished_default_work != len(self._default_futures): |
| 166 | + logging.info('%d out of %d default-rewards modules are finished.', |
| 167 | + finished_default_work, len(self._default_futures)) |
| 168 | + return None |
| 169 | + |
| 170 | + if len(finished_work) != len(self._current_work): # on 1st iter both are 0 |
| 171 | + logging.info('%d out of %d modules are finished.', len(finished_work), |
| 172 | + len(self._current_work)) |
| 173 | + return None |
| 174 | + module_specs = self._corpus.modules |
| 175 | + results = self._schedule_jobs(policy_path, module_specs) |
| 176 | + self._current_work = list(zip(module_specs, results)) |
| 177 | + prev_policy = self._running_policy |
| 178 | + self._running_policy = step |
| 179 | + |
| 180 | + if len(finished_work) == 0: # 1st iteration this is 0 |
| 181 | + return None |
| 182 | + |
| 183 | + # Since all work is done: reset clock. Essential if processes never paused. |
| 184 | + if self._last_time is not None: |
| 185 | + cur_time = time.time() |
| 186 | + self._elapsed_time += cur_time - self._last_time |
| 187 | + self._last_time = cur_time |
| 188 | + |
| 189 | + successful_work = [(spec, res.result()) |
| 190 | + for spec, res in finished_work |
| 191 | + if not worker.get_exception(res)] |
| 192 | + failures = len(finished_work) - len(successful_work) |
| 193 | + |
| 194 | + logging.info('%d of %d modules finished in %d seconds (%d failures).', |
| 195 | + len(finished_work), self._num_modules, self._elapsed_time, |
| 196 | + failures) |
| 197 | + |
| 198 | + sum_policy = 0 |
| 199 | + sum_default = 0 |
| 200 | + for spec, res in successful_work: |
| 201 | + # res format: {_DEFAULT_IDENTIFIER: (None, native_size)} |
| 202 | +      for identifier, (_, policy_reward) in res.items(): |
| 203 | + sum_policy += policy_reward |
| 204 | + sum_default += self._default_rewards[spec.name][identifier] |
| 205 | + |
| 206 | + if sum_default <= 0: |
| 207 | +      raise ValueError('Sum of default rewards must be positive.') |
| 208 | + reward = 1 - sum_policy / sum_default |
| 209 | + |
| 210 | + monitor_dict = { |
| 211 | + prev_policy: { |
| 212 | + 'success_modules': len(successful_work), |
| 213 | + 'compile_wall_time': self._elapsed_time, |
| 214 | + 'sum_reward': reward |
| 215 | + } |
| 216 | + } |
| 217 | +    self._elapsed_time = 0  # Reset only once a batch fully completes. |
| 218 | + return monitor_dict |
| 219 | + |
| 220 | + def pause_children(self): |
| 221 | + if not self._context_local or self._running_policy is None: |
| 222 | + return |
| 223 | + |
| 224 | + for p in self._worker_pool: |
| 225 | + p.pause_all_work() |
| 226 | + |
| 227 | + if self._last_time is not None: |
| 228 | + self._elapsed_time += time.time() - self._last_time |
| 229 | + self._last_time = None |
| 230 | + |
| 231 | + def resume_children(self): |
| 232 | + last_time_was_none = False |
| 233 | + if self._last_time is None: |
| 234 | + last_time_was_none = True |
| 235 | + self._last_time = time.time() |
| 236 | + |
| 237 | + if not self._context_local or self._running_policy is None: |
| 238 | + return |
| 239 | + |
| 240 | +    # Only pause_children() sets _last_time back to None. |
| 241 | + if last_time_was_none: |
| 242 | + for p in self._worker_pool: |
| 243 | + p.resume_all_work() |
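
For review context, here is a minimal sketch of how a training loop might drive this collector, using only the methods added above; `policy_saver.save(...)` and `run_main_collection_and_training(...)` are hypothetical placeholders for the surrounding trainer, not part of this change.

```python
# Illustrative driver only; policy_saver and run_main_collection_and_training
# are hypothetical stand-ins for the surrounding training loop.
from absl import logging


def validation_train_loop(collector, policy_saver, num_steps: int):
  for step in range(num_steps):
    policy_path = policy_saver.save(step)  # hypothetical: writes a policy dir
    # collect_data_async returns None while the previously scheduled
    # validation batch is still compiling, and a {prev_step: metrics} dict
    # once that batch has finished.
    metrics = collector.collect_data_async(policy_path, step)
    if metrics is not None:
      logging.info('validation metrics: %s', metrics)
    # Keep the nice-19 validation workers paused while the main (higher
    # priority) collection runs; the next collect_data_async call resumes them.
    collector.pause_children()
    run_main_collection_and_training(step)  # hypothetical
```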