Commit 16e9a86

Introduce Validation Data Collector
- Allows concurrent evaluation of models on a separate dataset during training, enabled with --validation_data_path.
- This is done with minimal impact on training time: the validation dataset only uses the CPU while it is mostly idle during tf.train(), and processes are pinned to specific CPUs.
- The amount of impact can be adjusted via a gin config on cpu_affinity.py.
- CPU affinities are only optimized for internal AMD Zen-based systems at the moment, but can be extended in the future.
1 parent 99671e4 commit 16e9a86

11 files changed: 380 additions & 28 deletions

compiler_opt/distributed/local/local_worker_manager.py

Lines changed: 40 additions & 1 deletion
@@ -32,14 +32,17 @@
 import functools
 import multiprocessing
 import threading
+import os
+import psutil
+import signal
 
 from absl import logging
 # pylint: disable=unused-import
 from compiler_opt.distributed.worker import Worker
 
 from contextlib import AbstractContextManager
 from multiprocessing import connection
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, List
 
 
 @dataclasses.dataclass(frozen=True)
@@ -131,6 +134,7 @@ def __init__(self):
     # when we stop.
     self._lock = threading.Lock()
     self._map: Dict[int, concurrent.futures.Future] = {}
+    self.is_paused = False
 
     # thread draining the pipe
     self._pump = threading.Thread(target=self._msg_pump)
@@ -205,10 +209,37 @@ def shutdown(self):
     try:
       # Killing the process triggers observer exit, which triggers msg_pump
       # exit
+      self.resume()
       self._process.kill()
     except:  # pylint: disable=bare-except
       pass
 
+  def pause(self):
+    if self.is_paused:
+      return
+    self.is_paused = True
+    # Sends the STOP signal; does not actually kill the process.
+    os.kill(self._process.pid, signal.SIGSTOP)
+
+  def resume(self):
+    if not self.is_paused:
+      return
+    self.is_paused = False
+    # Sends the CONTINUE signal; does not actually kill the process.
+    os.kill(self._process.pid, signal.SIGCONT)
+
+  def set_nice(self, val: int):
+    """Sets the niceness of the process, which modifies how the OS
+    schedules it. Only works on Unix, since val is presumed to be an int.
+    """
+    psutil.Process(self._process.pid).nice(val)
+
+  def set_affinity(self, val: List[int]):
+    """Sets the CPU affinity of the process, which modifies which cores
+    the OS schedules it on.
+    """
+    psutil.Process(self._process.pid).cpu_affinity(val)
+
   def join(self):
     self._observer.join()
     self._pump.join()
@@ -242,3 +273,11 @@ def __exit__(self, *args):
     # now wait for the message pumps to indicate they exit.
     for s in self._stubs:
       s.join()
+
+  def __del__(self):
+    self.__exit__()
+
+  @property
+  def stubs(self):
+    # Return a shallow copy, to avoid callers mutating the internal list.
+    return list(self._stubs)
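
The pause/resume pair is what lets the validation collector vacate the CPUs on demand. A minimal sketch of how the new stub controls compose, assuming `stub` is one of the per-process stubs managed by this file (the park/unpark helper names are illustrative, not part of this commit):

def park(stub):
  stub.set_nice(19)        # lowest scheduling priority (Unix niceness)
  stub.pause()             # SIGSTOP: stops consuming CPU, keeps all state

def unpark(stub, cpus):
  stub.set_affinity(cpus)  # e.g. a list from cpu_affinity.set_and_get below
  stub.resume()            # SIGCONT: the process continues where it stopped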

compiler_opt/distributed/worker.py

Lines changed: 9 additions & 2 deletions
@@ -15,17 +15,24 @@
 """Common abstraction for a worker contract."""
 
 import abc
-from typing import Generic, Iterable, Optional, TypeVar
+from typing import Generic, Iterable, Optional, TypeVar, Protocol, runtime_checkable
 
 
-class Worker:
+class Worker(Protocol):
 
   @classmethod
   def is_priority_method(cls, method_name: str) -> bool:
     _ = method_name
     return False
 
 
+@runtime_checkable
+class ContextAwareWorker(Worker, Protocol):
+
+  def set_context(self, local: bool) -> None:
+    return
+
+
 T = TypeVar('T')
 
compiler_opt/rl/compilation_runner.py

Lines changed: 18 additions & 11 deletions
@@ -212,10 +212,12 @@ class CompilationResult:
 
   def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
     object.__setattr__(self, 'serialized_sequence_examples',
-                       [x.SerializeToString() for x in sequence_examples])
+                       [(x.SerializeToString() if x is not None else None)
+                        for x in sequence_examples])
     lengths = [
         len(next(iter(x.feature_lists.feature_list.values())).feature)
         for x in sequence_examples
+        if x is not None
     ]
     object.__setattr__(self, 'length', sum(lengths))
 
@@ -229,10 +231,9 @@ class CompilationRunnerStub(metaclass=abc.ABCMeta):
   """The interface of a stub to CompilationRunner, for type checkers."""
 
   @abc.abstractmethod
-  def collect_data(
-      self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
-      reward_stat: Optional[Dict[str, RewardStat]]
-  ) -> WorkerFuture[CompilationResult]:
+  def collect_data(self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
+                   reward_stat: Optional[Dict[str, RewardStat]],
+                   raw_reward_only: bool) -> WorkerFuture[CompilationResult]:
     raise NotImplementedError()
 
   @abc.abstractmethod
@@ -275,17 +276,18 @@ def enable(self):
   def cancel_all_work(self):
     self._cancellation_manager.kill_all_processes()
 
-  def collect_data(
-      self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
-      reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult:
+  def collect_data(self,
+                   module_spec: corpus.ModuleSpec,
+                   tf_policy_path: str,
+                   reward_stat: Optional[Dict[str, RewardStat]],
+                   raw_reward_only=False) -> CompilationResult:
     """Collect data for the given IR file and policy.
 
     Args:
       module_spec: a ModuleSpec.
      tf_policy_path: path to the tensorflow policy.
       reward_stat: reward stat of this module, None if unknown.
-      cancellation_token: a CancellationToken through which workers may be
-        signaled early termination
+      raw_reward_only: whether to return the raw reward value without examples
 
     Returns:
       A CompilationResult. In particular:
@@ -311,7 +313,7 @@ def collect_data(
       policy_result = self._compile_fn(
           module_spec,
           tf_policy_path,
-          reward_only=False,
+          reward_only=raw_reward_only,
           cancellation_manager=self._cancellation_manager)
     else:
       policy_result = default_result
@@ -327,6 +329,11 @@ def collect_data(
             (f'Example {k} does not exist under default policy for '
              f'module {module_spec.name}'))
       default_reward = reward_stat[k].default_reward
+      if raw_reward_only:
+        sequence_example_list.append(None)
+        rewards.append(policy_reward)
+        keys.append(k)
+        continue
       moving_average_reward = reward_stat[k].moving_average_reward
       sequence_example = _overwrite_trajectory_reward(
           sequence_example=sequence_example,
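
The raw_reward_only path appends None in place of each trajectory, which is why __post_init__ above had to become None-tolerant. A self-contained sketch of that serialization logic (standalone, not the dataclass itself):

import tensorflow as tf

# One real trajectory with a single reward step, plus a None placeholder
# as appended under raw_reward_only.
ex = tf.train.SequenceExample()
ex.feature_lists.feature_list['reward'].feature.add().float_list.value.append(1.0)
examples = [ex, None]

serialized = [(x.SerializeToString() if x is not None else None)
              for x in examples]
assert serialized[1] is None

lengths = [
    len(next(iter(x.feature_lists.feature_list.values())).feature)
    for x in examples
    if x is not None
]
assert sum(lengths) == 1  # the None placeholder contributes no length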

compiler_opt/rl/cpu_affinity.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utility functions to set cpu affinities when operating main and
+subprocesses simultaneously."""
+import gin
+import psutil
+import itertools
+
+N = psutil.cpu_count()
+
+CPU_CONFIG = {  # List of CPU numbers in cache-sharing order.
+    # 'google-epyc' assumes logical cores 0 and N/2 share the same physical
+    # core, and that L3 cache is shared between consecutive core numbers.
+    'google-epyc': list(itertools.chain(*zip(range(N // 2), range(N // 2, N))))
+}
+
+
+@gin.configurable
+def set_and_get(is_main_process: bool,
+                max_cpus=N,
+                min_main_cpu: int = 32,
+                arch: str = 'google-epyc'):
+  """Sets the cpu affinity of the current process to appropriate values and
+  returns the list of cpus the process is set to use.
+
+  Args:
+    is_main_process: whether the caller is the main process.
+    max_cpus: maximal number of cpus to use.
+    min_main_cpu: number of cpus to assign to the main process.
+    arch: the system type, used to infer the cpu cache architecture.
+  """
+  config = CPU_CONFIG[arch][:max_cpus]
+  if is_main_process:
+    cpus = config[:min_main_cpu]
+  else:
+    cpus = config[min_main_cpu:]
+  if len(cpus) == 0:
+    raise ValueError('Attempting to set cpu affinity of process to nothing.')
+  psutil.Process().cpu_affinity(cpus)
+  return list(cpus)
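
To make the cache-sharing order concrete, here is what the 'google-epyc' config works out to on a hypothetical 8-CPU machine (N=8 is assumed purely for illustration):

import itertools

N = 8  # pretend cpu count, for illustration only
config = list(itertools.chain(*zip(range(N // 2), range(N // 2, N))))
assert config == [0, 4, 1, 5, 2, 6, 3, 7]

# With min_main_cpu=2 the main process would take both hyperthreads of
# physical core 0, and subprocesses get the remaining cores in L3 order.
main_cpus, subprocess_cpus = config[:2], config[2:]
assert main_cpus == [0, 4]
assert subprocess_cpus == [1, 5, 2, 6, 3, 7]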

compiler_opt/rl/local_data_collector.py

Lines changed: 11 additions & 6 deletions
@@ -18,7 +18,7 @@
 import itertools
 import random
 import time
-from typing import Callable, Dict, Iterator, List, Tuple, Optional
+from typing import Callable, Dict, Iterator, List, Tuple, Optional, Any
 
 from absl import logging
 from tf_agents.trajectories import trajectory
@@ -55,7 +55,7 @@ def __init__(
     # We remove this activity from the critical path by running it concurrently
     # with the training phase - i.e. whatever happens between successive data
     # collection calls. Subsequent runs will wait for these to finish.
-    self._reset_workers: concurrent.futures.Future = None
+    self._reset_workers: Optional[concurrent.futures.Future] = None
     self._current_work: List[Tuple[corpus.ModuleSpec, worker.WorkerFuture]] = []
     self._pool = concurrent.futures.ThreadPoolExecutor()
 
@@ -77,20 +77,25 @@ def _join_pending_jobs(self):
     logging.info('Waiting for pending work from last iteration took %f',
                  time.time() - t1)
 
+  def _create_jobs(
+      self, policy_path: str, sampled_modules: List[corpus.ModuleSpec]
+  ) -> Tuple[List[Tuple[Any, ...]], List[Optional[Dict]]]:
+    return [(module_spec, policy_path, self._reward_stat_map[module_spec.name])
+            for module_spec in sampled_modules], [{}] * len(sampled_modules)
+
   def _schedule_jobs(
       self, policy_path: str, sampled_modules: List[corpus.ModuleSpec]
   ) -> List[worker.WorkerFuture[compilation_runner.CompilationResult]]:
     # by now, all the pending work, which was signaled to cancel, must've
     # finished
     self._join_pending_jobs()
-    jobs = [(module_spec, policy_path, self._reward_stat_map[module_spec.name])
-            for module_spec in sampled_modules]
+    args, kwargs = self._create_jobs(policy_path, sampled_modules)
 
     # Naive load balancing.
     ret = []
-    for i in range(len(jobs)):
+    for i, (arg, kwarg) in enumerate(zip(args, kwargs)):
       ret.append(self._worker_pool[i % len(self._worker_pool)].collect_data(
-          *(jobs[i])))
+          *arg, **kwarg))
     return ret
 
   def collect_data(
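
_create_jobs is factored out so a subclass can change per-job arguments without touching the scheduling loop; presumably this is the hook the validation collector uses to pass the new raw_reward_only flag. A hedged sketch, assuming the collector class in this file is LocalDataCollector (RawRewardCollector is illustrative, not the class this commit adds):

from compiler_opt.rl.local_data_collector import LocalDataCollector

class RawRewardCollector(LocalDataCollector):
  """Illustrative subclass: identical scheduling, but every job asks the
  runner for raw rewards only."""

  def _create_jobs(self, policy_path, sampled_modules):
    args, kwargs = super()._create_jobs(policy_path, sampled_modules)
    # Each kwarg dict is splatted into collect_data(), whose new keyword
    # (per compilation_runner.py above) is raw_reward_only.
    return args, [{'raw_reward_only': True} for _ in kwargs]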

compiler_opt/rl/local_data_collector_test.py

Lines changed: 5 additions & 1 deletion
@@ -80,7 +80,11 @@ def mock_collect_data(module_spec, tf_policy_dir, reward_stat):
 class Sleeper(compilation_runner.CompilationRunner):
   """Test CompilationRunner that just sleeps."""
 
-  def collect_data(self, module_spec, tf_policy_path, reward_stat):
+  def collect_data(self,
+                   module_spec,
+                   tf_policy_path,
+                   reward_stat,
+                   raw_reward_only=False):
     _ = module_spec, tf_policy_path, reward_stat
     compilation_runner.start_cancellable_process(['sleep', '3600s'], 3600,
                                                  self._cancellation_manager)
