Commit bd8f87a
Introduce Validation Data Collector
- Allows concurrent evaluation of models on a separate dataset during training, enabled via --validation_data_path.
- Training time is minimally impacted: the validation dataset only uses the CPU while it is mostly idle in tf.train(), and processes are pinned to specific CPUs.
- The amount of impact can be adjusted via gin config on cpu_affinity.py.
- CPU affinities are currently only optimized for internal AMD Zen-based systems, but can be extended in the future.
1 parent 390faa2 commit bd8f87a
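For illustration, the collector's CPU footprint might be tuned through gin bindings on cpu_affinity.py's set_and_get (shown in the new file below); the binding values here are hypothetical, a sketch rather than recommended settings:

import gin
from compiler_opt.distributed.local import cpu_affinity  # registers set_and_get

# Hypothetical bindings: reserve the first 16 CPUs of the cache-sharing order
# for the main (training) process, and consider at most 48 CPUs in total.
gin.parse_config("""
set_and_get.min_main_cpu = 16
set_and_get.max_cpus = 48
""")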

File tree

9 files changed: +395 -7 lines changed
compiler_opt/distributed/local/cpu_affinity.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions to set cpu affinities when operating main and
subprocesses simultaneously."""
import gin
import psutil
import itertools

_NR_CPUS = psutil.cpu_count()

_CPU_CONFIG = {  # List of CPU numbers in cache-sharing order.
    # 'google-epyc' assumes logical core 0 and N/2 are the same physical core.
    # Also, L3 cache is assumed to be shared between consecutive core numbers.
    'google-epyc':
        list(
            itertools.chain(
                *zip(range(_NR_CPUS // 2), range(_NR_CPUS // 2, _NR_CPUS))))
}


@gin.configurable
def set_and_get(is_main_process: bool,
                max_cpus=_NR_CPUS,
                min_main_cpu: int = 32,
                arch: str = 'google-epyc'):
  """Sets the cpu affinity of the current process to appropriate values, and
  returns the list of cpus the process is set to use.

  Args:
    is_main_process: whether the caller is the main process.
    max_cpus: maximal number of cpus to use.
    min_main_cpu: number of cpus to assign to the main process.
    arch: the system type, used to infer the cpu cache architecture.
  """
  config = _CPU_CONFIG[arch][:max_cpus]
  if is_main_process:
    cpus = config[:min_main_cpu]
  else:
    cpus = config[min_main_cpu:]
  if not cpus:
    raise ValueError('Attempting to set cpu affinity of process to nothing.')
  psutil.Process().cpu_affinity(cpus)
  return list(cpus)
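To make the 'google-epyc' ordering concrete, here is a small worked example; the 8-CPU machine is assumed purely for illustration:

import itertools

nr_cpus = 8  # hypothetical machine
# Pair each logical core 0..3 with its SMT sibling 4..7, so consecutive
# entries in the ordering share a physical core.
order = list(
    itertools.chain(*zip(range(nr_cpus // 2), range(nr_cpus // 2, nr_cpus))))
assert order == [0, 4, 1, 5, 2, 6, 3, 7]
# With min_main_cpu=4, set_and_get(is_main_process=True) would pin the caller
# to [0, 4, 1, 5]; subprocesses would share the remaining [2, 6, 3, 7].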
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for cpu_affinity."""

from absl.testing import absltest
from compiler_opt.distributed.local import cpu_affinity
# pylint: disable=protected-access


class CpuAffinityTest(absltest.TestCase):

  def test_tally(self):
    # Each config must enumerate every CPU exactly once.
    for v in cpu_affinity._CPU_CONFIG.values():
      self.assertLen(set(v), cpu_affinity._NR_CPUS)


if __name__ == '__main__':
  absltest.main()

compiler_opt/distributed/worker.py

Lines changed: 8 additions & 1 deletion
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Common abstraction for a worker contract."""
 
-from typing import Iterable, Optional, Protocol, TypeVar
+from typing import Iterable, Optional, TypeVar, Protocol, runtime_checkable
 
 
 class Worker(Protocol):
@@ -25,6 +25,13 @@ def is_priority_method(cls, method_name: str) -> bool:
     return False
 
 
+@runtime_checkable
+class ContextAwareWorker(Worker, Protocol):
+
+  def set_context(self, local: bool) -> None:
+    return
+
+
 T = TypeVar('T')
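Because ContextAwareWorker is marked @runtime_checkable, owners of heterogeneous worker stubs can probe for the capability at run time; a minimal sketch (the maybe_set_context helper is illustrative, not part of this commit):

from compiler_opt.distributed import worker

def maybe_set_context(w, local: bool) -> None:
  # The isinstance check is structural: any object exposing the protocol's
  # methods passes, whether or not it nominally subclasses ContextAwareWorker.
  if isinstance(w, worker.ContextAwareWorker):
    w.set_context(local)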

compiler_opt/rl/compilation_runner.py

Lines changed: 3 additions & 2 deletions
@@ -283,7 +283,8 @@ def is_priority_method(cls, method_name: str) -> bool:
   def __init__(self,
                clang_path: Optional[str] = None,
                launcher_path: Optional[str] = None,
-               moving_average_decay_rate: float = 1):
+               moving_average_decay_rate: float = 1,
+               compilation_timeout=_COMPILATION_TIMEOUT.value):
     """Initialization of CompilationRunner class.
 
     Args:
@@ -294,7 +295,7 @@ def __init__(self,
     self._clang_path = clang_path
     self._launcher_path = launcher_path
     self._moving_average_decay_rate = moving_average_decay_rate
-    self._compilation_timeout = _COMPILATION_TIMEOUT.value
+    self._compilation_timeout = compilation_timeout
     self._cancellation_manager = WorkerCancellationManager()
 
     # re-allow the cancellation manager accept work.
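Since the timeout is now a constructor argument rather than being read from the flag at construction time, individual runners can be given their own budget; a sketch (the clang path is hypothetical):

# Give validation compiles a 20-minute budget instead of the
# _COMPILATION_TIMEOUT flag default.
runner = CompilationRunner(
    clang_path='/usr/bin/clang',  # hypothetical path
    moving_average_decay_rate=1,
    compilation_timeout=1200)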

compiler_opt/rl/corpus.py

Lines changed: 4 additions & 0 deletions
@@ -122,6 +122,10 @@ def filter(self, p: re.Pattern):
     """Filters module specs, keeping those which match the provided pattern."""
     self._module_specs = [ms for ms in self._module_specs if p.match(ms.name)]
 
+  @property
+  def modules(self):
+    return list(self._module_specs)
+
   def __len__(self):
     return len(self._module_specs)
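The new property hands back a copy of the spec list, so callers can take a stable snapshot; a tiny sketch, assuming cps is a corpus.Corpus instance:

specs = cps.modules                 # snapshot of the module specs
assert len(specs) == len(cps)
specs.pop()                         # mutating the snapshot...
assert len(specs) == len(cps) - 1   # ...does not touch the corpus itself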

compiler_opt/rl/inlining/inlining_runner.py

Lines changed: 2 additions & 0 deletions
@@ -71,6 +71,8 @@ def compile_fn(
         cancelled work.
       RuntimeError: if llvm-size produces unexpected output.
     """
+    if cancellation_manager is None:
+      cancellation_manager = self._cancellation_manager
     working_dir = tempfile.mkdtemp()
 
     log_path = os.path.join(working_dir, 'log')
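With this fallback in place, remote callers can pass an explicit cancellation_manager of None, as the validation collector's job tuples do; a sketch of the call shape (argument names beyond the diff are assumptions):

# args: module spec, policy path, reward_only, cancellation_manager.
# None now means 'use the runner's own self._cancellation_manager'.
runner.compile_fn(module_spec, '', True, None)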
Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Validation data collection module."""
import concurrent.futures
import threading
import time
from typing import Dict, Optional, List, Tuple

from absl import logging

from compiler_opt.distributed import worker
from compiler_opt.distributed.local import buffered_scheduler
from compiler_opt.distributed.local import cpu_affinity
from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPool

from compiler_opt.rl import corpus


class LocalValidationDataCollector(worker.ContextAwareWorker):
  """Local implementation of a validation data collector.

  Args:
    cps: corpus of modules to collect validation data on.
    worker_pool_args: arguments used to construct the local worker pool.
    reward_stat_map: previously computed default rewards, keyed by module
      name, e.g. as loaded from disk.
    max_cpus: maximal number of cpus to use for validation workers.
  """

  def __init__(self, cps: corpus.Corpus, worker_pool_args, reward_stat_map,
               max_cpus):
    self._num_modules = len(cps) if cps is not None else 0
    self._corpus: corpus.Corpus = cps
    self._default_rewards = {}

    self._running_policy = None
    self._default_futures: List[worker.WorkerFuture] = []
    self._current_work: List[Tuple[corpus.ModuleSpec,
                                   worker.WorkerFuture]] = []
    self._last_time = None
    self._elapsed_time = 0

    self._context_local = True

    # Check a bit later so some expected vars have been set first.
    if not cps:
      return

    affinities = cpu_affinity.set_and_get(
        is_main_process=False, max_cpus=max_cpus)

    # Add some runner specific flags.
    logging.info('Validation data collector using %d workers.',
                 len(affinities))
    worker_pool_args['count'] = len(affinities)
    worker_pool_args['moving_average_decay_rate'] = 1
    worker_pool_args['compilation_timeout'] = 1200

    # Borrow from the external reward_stat_map in case it got loaded from disk
    # and already has some values. On a fresh run this will be recalculated
    # from scratch in the main data collector and here. It would be ideal if
    # both shared the same dict, but that would be too complex to implement.
    for name, data in reward_stat_map.items():
      if name not in self._default_rewards:
        self._default_rewards[name] = {}
      for identifier, reward_stat in data.items():
        self._default_rewards[name][identifier] = reward_stat.default_reward

    self._pool = LocalWorkerPool(**worker_pool_args)
    self._worker_pool = self._pool.stubs

    for i, p in zip(affinities, self._worker_pool):
      p.set_nice(19)
      p.set_affinity([i])

  # BEGIN: ContextAwareWorker methods
  @classmethod
  def is_priority_method(cls, _: str) -> bool:
    # Everything is a priority: this is essentially a synchronous RPC endpoint.
    return True

  def set_context(self, local: bool):
    self._context_local = local

  # END: ContextAwareWorker methods

  def _schedule_jobs(self, policy_path, module_specs):
    default_jobs = []
    for module_spec in module_specs:
      if module_spec.name not in self._default_rewards:
        # The bool is reward_only, None is cancellation_manager.
        default_jobs.append((module_spec, '', True, None))

    default_rewards_lock = threading.Lock()

    def create_update_rewards(spec_name):

      def updater(f: concurrent.futures.Future):
        if f.exception() is None:
          reward_stat = f.result()
          for identifier, (_, default_reward) in reward_stat.items():
            with default_rewards_lock:
              self._default_rewards.setdefault(
                  spec_name, {})[identifier] = default_reward

      return updater

    # The bool is reward_only, None is cancellation_manager.
    policy_jobs = [
        (module_spec, policy_path, True, None) for module_spec in module_specs
    ]

    def work_factory(job):

      def work(w):
        return w.compile_fn(*job)

      return work

    work = [work_factory(job) for job in default_jobs]
    work += [work_factory(job) for job in policy_jobs]

    futures = buffered_scheduler.schedule(work, self._worker_pool, buffer=10)

    self._default_futures = futures[:len(default_jobs)]
    policy_futures = futures[len(default_jobs):]

    for job, future in zip(default_jobs, self._default_futures):
      future.add_done_callback(create_update_rewards(job[0].name))

    return policy_futures

  def collect_data_async(
      self,
      policy_path: str,
      step: int = 0) -> Optional[Dict[int, Dict[str, float]]]:
    """Collect data for a given policy.

    Args:
      policy_path: the path to the policy directory to collect data with.
      step: the step number associated with the policy_path.

    Returns:
      Either returns data in the form of a dictionary, or returns None if the
      data is not ready yet.
    """
    if self._num_modules == 0:
      return None

    # Resume immediately, so that if new jobs are scheduled, they run while
    # the last batch's results are being processed.
    self.resume_children()
    finished_work = [
        (spec, res) for spec, res in self._current_work if res.done()
    ]

    # Check if there are default rewards being collected.
    if len(self._default_futures) > 0:
      finished_default_work = sum(res.done() for res in self._default_futures)
      if finished_default_work != len(self._default_futures):
        logging.info('%d out of %d default-rewards modules are finished.',
                     finished_default_work, len(self._default_futures))
        return None

    if len(finished_work) != len(self._current_work):  # on 1st iter both are 0
      logging.info('%d out of %d modules are finished.', len(finished_work),
                   len(self._current_work))
      return None
    module_specs = self._corpus.modules
    results = self._schedule_jobs(policy_path, module_specs)
    self._current_work = list(zip(module_specs, results))
    prev_policy = self._running_policy
    self._running_policy = step

    if len(finished_work) == 0:  # 1st iteration this is 0
      return None

    # Since all work is done: reset clock. Essential if processes never paused.
    if self._last_time is not None:
      cur_time = time.time()
      self._elapsed_time += cur_time - self._last_time
      self._last_time = cur_time

    successful_work = [(spec, res.result())
                       for spec, res in finished_work
                       if not worker.get_exception(res)]
    failures = len(finished_work) - len(successful_work)

    logging.info('%d of %d modules finished in %d seconds (%d failures).',
                 len(finished_work), self._num_modules, self._elapsed_time,
                 failures)

    sum_policy = 0
    sum_default = 0
    for spec, res in successful_work:
      # res format: {_DEFAULT_IDENTIFIER: (None, native_size)}
      for identifier, (_, policy_reward) in res.items():
        sum_policy += policy_reward
        sum_default += self._default_rewards[spec.name][identifier]

    if sum_default <= 0:
      raise ValueError('Sum of default rewards is 0.')
    reward = 1 - sum_policy / sum_default

    monitor_dict = {
        prev_policy: {
            'success_modules': len(successful_work),
            'compile_wall_time': self._elapsed_time,
            'sum_reward': reward
        }
    }
    self._elapsed_time = 0  # Only reset on completion.
    return monitor_dict

  def pause_children(self):
    if not self._context_local or self._running_policy is None:
      return

    for p in self._worker_pool:
      p.pause_all_work()

    if self._last_time is not None:
      self._elapsed_time += time.time() - self._last_time
    self._last_time = None

  def resume_children(self):
    last_time_was_none = False
    if self._last_time is None:
      last_time_was_none = True
      self._last_time = time.time()

    if not self._context_local or self._running_policy is None:
      return

    # Only pause changes last_time to None.
    if last_time_was_none:
      for p in self._worker_pool:
        p.resume_all_work()
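A sketch of how a training loop might drive the collector; the pool arguments, trainer call, and logging hook below are assumptions, not part of this commit:

collector = LocalValidationDataCollector(
    cps=validation_corpus,         # corpus built from --validation_data_path
    worker_pool_args={'worker_class': InliningRunner},  # hypothetical args
    reward_stat_map=reward_stat_map,
    max_cpus=16)

for step in range(num_iterations):        # hypothetical training loop
  collector.pause_children()              # return CPUs to tf.train()
  policy_path = run_training_iteration()  # hypothetical trainer call
  # Schedules evaluation of the new policy; returns the previous batch's
  # metrics once every module has finished, else None.
  metrics = collector.collect_data_async(policy_path, step)
  if metrics is not None:
    log_metrics(metrics)                  # hypothetical logging hook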
