diff --git a/compiler_opt/distributed/local/local_worker_manager.py b/compiler_opt/distributed/local/local_worker_manager.py index eb5b9f46..425d83b3 100644 --- a/compiler_opt/distributed/local/local_worker_manager.py +++ b/compiler_opt/distributed/local/local_worker_manager.py @@ -31,6 +31,7 @@ import dataclasses import functools import multiprocessing +import psutil import threading from absl import logging @@ -39,7 +40,7 @@ from contextlib import AbstractContextManager from multiprocessing import connection -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional @dataclasses.dataclass(frozen=True) @@ -214,6 +215,18 @@ def shutdown(self): except: # pylint: disable=bare-except pass + def set_nice(self, val: int): + """Sets the nice-ness of the process, this modifies how the OS + schedules it. Only works on Unix, since val is presumed to be an int. + """ + psutil.Process(self._process.pid).nice(val) + + def set_affinity(self, val: List[int]): + """Sets the CPU affinity of the process, this modifies which cores the OS + schedules it on. + """ + psutil.Process(self._process.pid).cpu_affinity(val) + def join(self): self._observer.join() self._pump.join() @@ -247,3 +260,11 @@ def __exit__(self, *args): # now wait for the message pumps to indicate they exit. for s in self._stubs: s.join() + + def __del__(self): + self.__exit__() + + @property + def stubs(self): + # Return a shallow copy, to avoid something messing the internal list up + return list(self._stubs) diff --git a/compiler_opt/distributed/local/local_worker_manager_test.py b/compiler_opt/distributed/local/local_worker_manager_test.py index 21a8c0ab..75aa9075 100644 --- a/compiler_opt/distributed/local/local_worker_manager_test.py +++ b/compiler_opt/distributed/local/local_worker_manager_test.py @@ -15,7 +15,6 @@ """Test for local worker manager.""" import concurrent.futures -import multiprocessing import time from absl.testing import absltest diff --git a/compiler_opt/distributed/worker.py b/compiler_opt/distributed/worker.py index 46b32df7..ff040d8e 100644 --- a/compiler_opt/distributed/worker.py +++ b/compiler_opt/distributed/worker.py @@ -14,10 +14,10 @@ # limitations under the License. """Common abstraction for a worker contract.""" -from typing import Iterable, Optional, TypeVar, Protocol +from typing import Iterable, Optional, Protocol, TypeVar -class Worker: +class Worker(Protocol): @classmethod def is_priority_method(cls, method_name: str) -> bool: diff --git a/compiler_opt/rl/compilation_runner.py b/compiler_opt/rl/compilation_runner.py index aa835de1..91060f80 100644 --- a/compiler_opt/rl/compilation_runner.py +++ b/compiler_opt/rl/compilation_runner.py @@ -18,6 +18,7 @@ import dataclasses import json import os +import signal import subprocess import threading from typing import Dict, List, Optional, Tuple @@ -108,6 +109,7 @@ def __init__(self): # empty() is accurate and get() never blocks. self._processes = set() self._done = False + self._paused = False self._lock = threading.Lock() def enable(self): @@ -129,10 +131,33 @@ def kill_all_processes(self): for p in self._processes: kill_process_ignore_exceptions(p) + def pause_all_processes(self): + with self._lock: + if self._paused: + return + self._paused = True + + for p in self._processes: + # used to send the STOP signal; does not actually kill the process + os.kill(p.pid, signal.SIGSTOP) + + def resume_all_processes(self): + with self._lock: + if not self._paused: + return + self._paused = False + + for p in self._processes: + # used to send the CONTINUE signal; does not actually kill the process + os.kill(p.pid, signal.SIGCONT) + def unregister_process(self, p: 'subprocess.Popen[bytes]'): with self._lock: - if not self._done: - self._processes.remove(p) + self._processes.remove(p) + + def __del__(self): + if len(self._processes) > 0: + raise RuntimeError('Cancellation manager deleted while containing items.') def start_cancellable_process( @@ -174,6 +199,7 @@ def start_cancellable_process( finally: if cancellation_manager: cancellation_manager.unregister_process(p) + if retcode != 0: raise ProcessKilledError( ) if retcode == -9 else subprocess.CalledProcessError(retcode, cmdline) @@ -249,7 +275,9 @@ class CompilationRunner(Worker): @classmethod def is_priority_method(cls, method_name: str) -> bool: - return method_name == 'cancel_all_work' + return method_name in { + 'cancel_all_work', 'pause_all_work', 'resume_all_work' + } def __init__(self, clang_path: Optional[str] = None, @@ -275,6 +303,12 @@ def enable(self): def cancel_all_work(self): self._cancellation_manager.kill_all_processes() + def pause_all_work(self): + self._cancellation_manager.pause_all_processes() + + def resume_all_work(self): + self._cancellation_manager.resume_all_processes() + def collect_data( self, module_spec: corpus.ModuleSpec, tf_policy_path: str, reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult: diff --git a/compiler_opt/rl/compilation_runner_test.py b/compiler_opt/rl/compilation_runner_test.py index 28b7178f..4390d51d 100644 --- a/compiler_opt/rl/compilation_runner_test.py +++ b/compiler_opt/rl/compilation_runner_test.py @@ -17,6 +17,7 @@ import os import string import subprocess +import threading import time from unittest import mock @@ -213,9 +214,9 @@ def test_exception_handling(self, mock_compile_fn): self.assertEqual(1, mock_compile_fn.call_count) def test_start_subprocess_output(self): - ct = compilation_runner.WorkerCancellationManager() + cm = compilation_runner.WorkerCancellationManager() output = compilation_runner.start_cancellable_process( - ['ls', '-l'], timeout=100, cancellation_manager=ct, want_output=True) + ['ls', '-l'], timeout=100, cancellation_manager=cm, want_output=True) if output: output_str = output.decode('utf-8') else: @@ -235,6 +236,23 @@ def test_timeout_kills_process(self): time.sleep(2) self.assertFalse(os.path.exists(sentinel_file)) + def test_pause_resume(self): + cm = compilation_runner.WorkerCancellationManager() + start_time = time.time() + + def stop_and_start(): + time.sleep(0.25) + cm.pause_all_processes() + time.sleep(1) + cm.resume_all_processes() + + threading.Thread(target=stop_and_start).start() + compilation_runner.start_cancellable_process(['sleep', '0.5'], + 30, + cancellation_manager=cm) + # should be at least 1 second due to the pause. + self.assertGreater(time.time() - start_time, 1) + if __name__ == '__main__': tf.test.main() diff --git a/requirements.txt b/requirements.txt index 98456846..aa710053 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,6 +27,7 @@ oauthlib==3.1.1 opt-einsum==3.3.0 pillow==8.3.1 protobuf==3.17.3 +psutil==5.9.0 pyasn1==0.4.8 pyasn1_modules==0.2.8 pyglet==1.5.0