From d3bb4657d6d928b99efeae245e920e74a3731527 Mon Sep 17 00:00:00 2001 From: Jin Xin Ng Date: Thu, 11 Aug 2022 15:11:19 -0700 Subject: [PATCH 1/5] Add pause/resume/context to workers - Allows a user to start/stop processes at will, via OS signals SIGSTOP and SIGCONT. - Allows a user to bind processes to specific CPUs. - Allows local_worker_pool to be used outside of a context manager - Switch workers to be Protocol based, so Workers are effectively duck-typed (i.e. anything that has the required methods passes as a Worker) Part of #96 --- .../distributed/local/local_worker_manager.py | 41 ++++++++++++++++++- compiler_opt/distributed/worker.py | 15 ++++++- requirements.txt | 1 + 3 files changed, 54 insertions(+), 3 deletions(-) diff --git a/compiler_opt/distributed/local/local_worker_manager.py b/compiler_opt/distributed/local/local_worker_manager.py index eb5b9f46..9ab0a38b 100644 --- a/compiler_opt/distributed/local/local_worker_manager.py +++ b/compiler_opt/distributed/local/local_worker_manager.py @@ -32,6 +32,9 @@ import functools import multiprocessing import threading +import os +import psutil +import signal from absl import logging # pylint: disable=unused-import @@ -39,7 +42,7 @@ from contextlib import AbstractContextManager from multiprocessing import connection -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, List @dataclasses.dataclass(frozen=True) @@ -131,6 +134,7 @@ def __init__(self): # when we stop. self._lock = threading.Lock() self._map: Dict[int, concurrent.futures.Future] = {} + self.is_paused = False # thread draining the pipe self._pump = threading.Thread(target=self._msg_pump) @@ -210,10 +214,37 @@ def shutdown(self): try: # Killing the process triggers observer exit, which triggers msg_pump # exit + self.resume() self._process.kill() except: # pylint: disable=bare-except pass + def pause(self): + if self.is_paused: + return + self.is_paused = True + # used to send the STOP signal; does not actually kill the process + os.kill(self._process.pid, signal.SIGSTOP) + + def resume(self): + if not self.is_paused: + return + self.is_paused = False + # used to send the CONTINUE signal; does not actually kill the process + os.kill(self._process.pid, signal.SIGCONT) + + def set_nice(self, val: int): + """Sets the nice-ness of the process, this modifies how the OS + schedules it. Only works on Unix, since val is presumed to be an int. + """ + psutil.Process(self._process.pid).nice(val) + + def set_affinity(self, val: List[int]): + """Sets the CPU affinity of the process, this modifies which cores the OS + schedules it on. + """ + psutil.Process(self._process.pid).cpu_affinity(val) + def join(self): self._observer.join() self._pump.join() @@ -247,3 +278,11 @@ def __exit__(self, *args): # now wait for the message pumps to indicate they exit. for s in self._stubs: s.join() + + def __del__(self): + self.__exit__() + + @property + def stubs(self): + # Return a shallow copy, to avoid something messing the internal list up + return list(self._stubs) diff --git a/compiler_opt/distributed/worker.py b/compiler_opt/distributed/worker.py index 46b32df7..ddad65d9 100644 --- a/compiler_opt/distributed/worker.py +++ b/compiler_opt/distributed/worker.py @@ -14,10 +14,10 @@ # limitations under the License. """Common abstraction for a worker contract.""" -from typing import Iterable, Optional, TypeVar, Protocol +from typing import Iterable, Optional, TypeVar, Protocol, runtime_checkable -class Worker: +class Worker(Protocol): @classmethod def is_priority_method(cls, method_name: str) -> bool: @@ -25,6 +25,17 @@ def is_priority_method(cls, method_name: str) -> bool: return False +@runtime_checkable +class ContextAwareWorker(Worker, Protocol): + """ContextAwareWorkers use set_context to modify internal state, this allows + it to behave differently when run remotely vs locally. The user of a + ContextAwareWorker can check for this with isinstance(obj, ContextAwareWorker) + """ + + def set_context(self, local: bool) -> None: + return + + T = TypeVar('T') diff --git a/requirements.txt b/requirements.txt index 98456846..aa710053 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,6 +27,7 @@ oauthlib==3.1.1 opt-einsum==3.3.0 pillow==8.3.1 protobuf==3.17.3 +psutil==5.9.0 pyasn1==0.4.8 pyasn1_modules==0.2.8 pyglet==1.5.0 From 470d7e825657f3d48f36feaadd9df9c8b409af19 Mon Sep 17 00:00:00 2001 From: Jin Xin Ng Date: Thu, 11 Aug 2022 16:32:45 -0700 Subject: [PATCH 2/5] Add a test --- .../local/local_worker_manager_test.py | 48 ++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/compiler_opt/distributed/local/local_worker_manager_test.py b/compiler_opt/distributed/local/local_worker_manager_test.py index 21a8c0ab..a6527333 100644 --- a/compiler_opt/distributed/local/local_worker_manager_test.py +++ b/compiler_opt/distributed/local/local_worker_manager_test.py @@ -15,7 +15,6 @@ """Test for local worker manager.""" import concurrent.futures -import multiprocessing import time from absl.testing import absltest @@ -59,6 +58,25 @@ def method(self): time.sleep(3600) +class JobCounter(Worker): + """Test worker.""" + + def __init__(self): + self.times = [] + + @classmethod + def is_priority_method(cls, method_name: str) -> bool: + return method_name == 'get_times' + + def start(self): + while True: + self.times.append(time.time()) + time.sleep(0.05) + + def get_times(self): + return self.times + + class LocalWorkerManagerTest(absltest.TestCase): def test_pool(self): @@ -100,6 +118,34 @@ def test_worker_crash_while_waiting(self): with self.assertRaises(concurrent.futures.CancelledError): _ = f.result() + def test_pause_resume(self): + + with local_worker_manager.LocalWorkerPool(JobCounter, 1) as pool: + p = pool[0] + + # Fill the q for 1 second + p.start() + time.sleep(1) + + # Then pause the process for 1 second + p.pause() + time.sleep(1) + + # Then resume the process and wait 1 more second + p.resume() + time.sleep(1) + + times = p.get_times().result() + + # If pause/resume worked, there should be a gap of at least 0.5 seconds. + # Otherwise, this will throw an exception. + self.assertNotEmpty(times) + last_time = times[0] + for cur_time in times: + if cur_time - last_time > 0.5: + return + raise ValueError('Failed to find a 2 second gap in times.') + if __name__ == '__main__': multiprocessing.handle_test_main(absltest.main) From 1d537c4903f030563954e333f6da4974cdfcf33c Mon Sep 17 00:00:00 2001 From: Jin Xin Ng Date: Mon, 15 Aug 2022 09:56:55 -0700 Subject: [PATCH 3/5] Move logic to runners --- .../distributed/local/local_worker_manager.py | 18 --- .../local/local_worker_manager_test.py | 47 -------- compiler_opt/rl/compilation_runner.py | 111 ++++++++++++++---- compiler_opt/rl/compilation_runner_test.py | 28 ++++- compiler_opt/rl/inlining/inlining_runner.py | 2 - compiler_opt/rl/local_data_collector_test.py | 6 +- compiler_opt/rl/regalloc/regalloc_runner.py | 1 - 7 files changed, 114 insertions(+), 99 deletions(-) diff --git a/compiler_opt/distributed/local/local_worker_manager.py b/compiler_opt/distributed/local/local_worker_manager.py index 9ab0a38b..6431140f 100644 --- a/compiler_opt/distributed/local/local_worker_manager.py +++ b/compiler_opt/distributed/local/local_worker_manager.py @@ -32,9 +32,7 @@ import functools import multiprocessing import threading -import os import psutil -import signal from absl import logging # pylint: disable=unused-import @@ -134,7 +132,6 @@ def __init__(self): # when we stop. self._lock = threading.Lock() self._map: Dict[int, concurrent.futures.Future] = {} - self.is_paused = False # thread draining the pipe self._pump = threading.Thread(target=self._msg_pump) @@ -214,25 +211,10 @@ def shutdown(self): try: # Killing the process triggers observer exit, which triggers msg_pump # exit - self.resume() self._process.kill() except: # pylint: disable=bare-except pass - def pause(self): - if self.is_paused: - return - self.is_paused = True - # used to send the STOP signal; does not actually kill the process - os.kill(self._process.pid, signal.SIGSTOP) - - def resume(self): - if not self.is_paused: - return - self.is_paused = False - # used to send the CONTINUE signal; does not actually kill the process - os.kill(self._process.pid, signal.SIGCONT) - def set_nice(self, val: int): """Sets the nice-ness of the process, this modifies how the OS schedules it. Only works on Unix, since val is presumed to be an int. diff --git a/compiler_opt/distributed/local/local_worker_manager_test.py b/compiler_opt/distributed/local/local_worker_manager_test.py index a6527333..75aa9075 100644 --- a/compiler_opt/distributed/local/local_worker_manager_test.py +++ b/compiler_opt/distributed/local/local_worker_manager_test.py @@ -58,25 +58,6 @@ def method(self): time.sleep(3600) -class JobCounter(Worker): - """Test worker.""" - - def __init__(self): - self.times = [] - - @classmethod - def is_priority_method(cls, method_name: str) -> bool: - return method_name == 'get_times' - - def start(self): - while True: - self.times.append(time.time()) - time.sleep(0.05) - - def get_times(self): - return self.times - - class LocalWorkerManagerTest(absltest.TestCase): def test_pool(self): @@ -118,34 +99,6 @@ def test_worker_crash_while_waiting(self): with self.assertRaises(concurrent.futures.CancelledError): _ = f.result() - def test_pause_resume(self): - - with local_worker_manager.LocalWorkerPool(JobCounter, 1) as pool: - p = pool[0] - - # Fill the q for 1 second - p.start() - time.sleep(1) - - # Then pause the process for 1 second - p.pause() - time.sleep(1) - - # Then resume the process and wait 1 more second - p.resume() - time.sleep(1) - - times = p.get_times().result() - - # If pause/resume worked, there should be a gap of at least 0.5 seconds. - # Otherwise, this will throw an exception. - self.assertNotEmpty(times) - last_time = times[0] - for cur_time in times: - if cur_time - last_time > 0.5: - return - raise ValueError('Failed to find a 2 second gap in times.') - if __name__ == '__main__': multiprocessing.handle_test_main(absltest.main) diff --git a/compiler_opt/rl/compilation_runner.py b/compiler_opt/rl/compilation_runner.py index aa835de1..6374d374 100644 --- a/compiler_opt/rl/compilation_runner.py +++ b/compiler_opt/rl/compilation_runner.py @@ -18,9 +18,11 @@ import dataclasses import json import os +import signal import subprocess import threading -from typing import Dict, List, Optional, Tuple +import time +from typing import Dict, List, Optional, Tuple, Union from absl import flags from compiler_opt.distributed.worker import Worker, WorkerFuture @@ -102,13 +104,22 @@ class WorkerCancellationManager: managing resources. """ - def __init__(self): + @dataclasses.dataclass + class ProcData: + process: 'subprocess.Popen[bytes]' + timeout: threading.Timer + time_left: float + start_time: float + + def __init__(self, timeout: float = _COMPILATION_TIMEOUT.value): # the queue is filled only by workers, and drained only by the single # consumer. we use _done to manage access to the queue. We can then assume # empty() is accurate and get() never blocks. - self._processes = set() + self._processes: Dict[int, WorkerCancellationManager.ProcData] = {} self._done = False + self._paused = False self._lock = threading.Lock() + self._timeout = timeout def enable(self): with self._lock: @@ -118,7 +129,13 @@ def register_process(self, p: 'subprocess.Popen[bytes]'): """Register a process for potential cancellation.""" with self._lock: if not self._done: - self._processes.add(p) + self._processes[p.pid] = self.ProcData( + process=p, + timeout=threading.Timer(self._timeout, + kill_process_ignore_exceptions, (p,)), + time_left=self._timeout, + start_time=time.time()) + self._processes[p.pid].timeout.start() return kill_process_ignore_exceptions(p) @@ -126,18 +143,56 @@ def kill_all_processes(self): """Cancel any pending work.""" with self._lock: self._done = True - for p in self._processes: - kill_process_ignore_exceptions(p) + for pdata in self._processes.values(): + kill_process_ignore_exceptions(pdata.process) + + def pause_all_processes(self): + with self._lock: + if self._paused: + return + self._paused = True + + cur_time = time.time() + for pid, pdata in self._processes.items(): + pdata.timeout.cancel() + pdata.time_left -= cur_time - pdata.start_time + if pdata.time_left > 0: + # used to send the STOP signal; does not actually kill the process + os.kill(pid, signal.SIGSTOP) + else: + # In case we cancelled right after the timeout expired, + # but before actually killing the process. + kill_process_ignore_exceptions(pdata.process) + + def resume_all_processes(self): + with self._lock: + if not self._paused: + return + self._paused = False + + cur_time = time.time() + for pid, pdata in self._processes.items(): + pdata.timeout = threading.Timer(pdata.time_left, + kill_process_ignore_exceptions, + (pdata.process,)) + pdata.timeout.start() + pdata.start_time = cur_time + # used to send the CONTINUE signal; does not actually kill the process + os.kill(pid, signal.SIGCONT) def unregister_process(self, p: 'subprocess.Popen[bytes]'): with self._lock: - if not self._done: - self._processes.remove(p) + if p.pid in self._processes: + self._processes[p.pid].timeout.cancel() + del self._processes[p.pid] + + def __del__(self): + if len(self._processes) > 0: + raise RuntimeError('Cancellation manager deleted while containing items.') def start_cancellable_process( cmdline: List[str], - timeout: float, cancellation_manager: Optional[WorkerCancellationManager], want_output: bool = False) -> Optional[bytes]: """Start a cancellable process. @@ -166,14 +221,10 @@ def start_cancellable_process( if cancellation_manager: cancellation_manager.register_process(p) - try: - retcode = p.wait(timeout=timeout) - except subprocess.TimeoutExpired as e: - kill_process_ignore_exceptions(p) - raise e - finally: - if cancellation_manager: - cancellation_manager.unregister_process(p) + retcode = p.wait() + + if cancellation_manager: + cancellation_manager.unregister_process(p) if retcode != 0: raise ProcessKilledError( ) if retcode == -9 else subprocess.CalledProcessError(retcode, cmdline) @@ -249,12 +300,16 @@ class CompilationRunner(Worker): @classmethod def is_priority_method(cls, method_name: str) -> bool: - return method_name == 'cancel_all_work' - - def __init__(self, - clang_path: Optional[str] = None, - launcher_path: Optional[str] = None, - moving_average_decay_rate: float = 1): + return method_name in { + 'cancel_all_work', 'pause_all_work', 'resume_all_work' + } + + def __init__( + self, + clang_path: Optional[str] = None, + launcher_path: Optional[str] = None, + moving_average_decay_rate: float = 1, + cancellation_manager: Optional[WorkerCancellationManager] = None): """Initialization of CompilationRunner class. Args: @@ -265,8 +320,8 @@ def __init__(self, self._clang_path = clang_path self._launcher_path = launcher_path self._moving_average_decay_rate = moving_average_decay_rate - self._compilation_timeout = _COMPILATION_TIMEOUT.value - self._cancellation_manager = WorkerCancellationManager() + self._cancellation_manager = ( + cancellation_manager or WorkerCancellationManager()) # re-allow the cancellation manager accept work. def enable(self): @@ -275,6 +330,12 @@ def enable(self): def cancel_all_work(self): self._cancellation_manager.kill_all_processes() + def pause_all_work(self): + self._cancellation_manager.pause_all_processes() + + def resume_all_work(self): + self._cancellation_manager.resume_all_processes() + def collect_data( self, module_spec: corpus.ModuleSpec, tf_policy_path: str, reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult: diff --git a/compiler_opt/rl/compilation_runner_test.py b/compiler_opt/rl/compilation_runner_test.py index 28b7178f..b029817a 100644 --- a/compiler_opt/rl/compilation_runner_test.py +++ b/compiler_opt/rl/compilation_runner_test.py @@ -17,6 +17,7 @@ import os import string import subprocess +import threading import time from unittest import mock @@ -213,9 +214,9 @@ def test_exception_handling(self, mock_compile_fn): self.assertEqual(1, mock_compile_fn.call_count) def test_start_subprocess_output(self): - ct = compilation_runner.WorkerCancellationManager() + cm = compilation_runner.WorkerCancellationManager(100) output = compilation_runner.start_cancellable_process( - ['ls', '-l'], timeout=100, cancellation_manager=ct, want_output=True) + ['ls', '-l'], cancellation_manager=cm, want_output=True) if output: output_str = output.decode('utf-8') else: @@ -227,14 +228,31 @@ def test_timeout_kills_process(self): 'test_timeout_kills_test_file') if os.path.exists(sentinel_file): os.remove(sentinel_file) - with self.assertRaises(subprocess.TimeoutExpired): + with self.assertRaises(compilation_runner.ProcessKilledError): + cm = compilation_runner.WorkerCancellationManager(0.5) compilation_runner.start_cancellable_process( ['bash', '-c', 'sleep 1s ; touch ' + sentinel_file], - timeout=0.5, - cancellation_manager=None) + cancellation_manager=cm) time.sleep(2) self.assertFalse(os.path.exists(sentinel_file)) + def test_pause_resume(self): + # This also makes sure timeouts are restored properly. + cm = compilation_runner.WorkerCancellationManager(1) + start_time = time.time() + + def stop_and_start(): + time.sleep(0.25) + cm.pause_all_processes() + time.sleep(2) + cm.resume_all_processes() + + threading.Thread(target=stop_and_start).start() + compilation_runner.start_cancellable_process(['sleep', '0.5'], + cancellation_manager=cm) + # should be at least 2 seconds due to the pause. + self.assertGreater(time.time() - start_time, 2) + if __name__ == '__main__': tf.test.main() diff --git a/compiler_opt/rl/inlining/inlining_runner.py b/compiler_opt/rl/inlining/inlining_runner.py index 69730f0a..7c0ca181 100644 --- a/compiler_opt/rl/inlining/inlining_runner.py +++ b/compiler_opt/rl/inlining/inlining_runner.py @@ -90,12 +90,10 @@ def compile_fn( command_line.extend( ['-mllvm', '-ml-inliner-model-under-training=' + tf_policy_path]) compilation_runner.start_cancellable_process(command_line, - self._compilation_timeout, cancellation_manager) command_line = [self._llvm_size_path, output_native_path] output_bytes = compilation_runner.start_cancellable_process( command_line, - timeout=self._compilation_timeout, cancellation_manager=cancellation_manager, want_output=True) if not output_bytes: diff --git a/compiler_opt/rl/local_data_collector_test.py b/compiler_opt/rl/local_data_collector_test.py index 35491fae..88b9a050 100644 --- a/compiler_opt/rl/local_data_collector_test.py +++ b/compiler_opt/rl/local_data_collector_test.py @@ -80,9 +80,13 @@ def mock_collect_data(module_spec, tf_policy_dir, reward_stat): class Sleeper(compilation_runner.CompilationRunner): """Test CompilationRunner that just sleeps.""" + def __init__(self): + super().__init__( + cancellation_manager=compilation_runner.WorkerCancellationManager(3600)) + def collect_data(self, module_spec, tf_policy_path, reward_stat): _ = module_spec, tf_policy_path, reward_stat - compilation_runner.start_cancellable_process(['sleep', '3600s'], 3600, + compilation_runner.start_cancellable_process(['sleep', '3600s'], self._cancellation_manager) return compilation_runner.CompilationResult( diff --git a/compiler_opt/rl/regalloc/regalloc_runner.py b/compiler_opt/rl/regalloc/regalloc_runner.py index ef6cf0ce..637a40b2 100644 --- a/compiler_opt/rl/regalloc/regalloc_runner.py +++ b/compiler_opt/rl/regalloc/regalloc_runner.py @@ -87,7 +87,6 @@ def compile_fn( if tf_policy_path: command_line.extend(['-mllvm', '-regalloc-model=' + tf_policy_path]) compilation_runner.start_cancellable_process(command_line, - self._compilation_timeout, cancellation_manager) sequence_example = struct_pb2.Struct() From 3c71bf36f228e34465be5b4f20da05f89ba5cbd4 Mon Sep 17 00:00:00 2001 From: Jin Xin Ng Date: Mon, 15 Aug 2022 10:33:15 -0700 Subject: [PATCH 4/5] Immediately pause on register --- compiler_opt/rl/compilation_runner.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/compiler_opt/rl/compilation_runner.py b/compiler_opt/rl/compilation_runner.py index 6374d374..4b8f1a40 100644 --- a/compiler_opt/rl/compilation_runner.py +++ b/compiler_opt/rl/compilation_runner.py @@ -22,7 +22,7 @@ import subprocess import threading import time -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple from absl import flags from compiler_opt.distributed.worker import Worker, WorkerFuture @@ -135,7 +135,10 @@ def register_process(self, p: 'subprocess.Popen[bytes]'): kill_process_ignore_exceptions, (p,)), time_left=self._timeout, start_time=time.time()) - self._processes[p.pid].timeout.start() + if self._paused: + os.kill(p.pid, signal.SIGSTOP) + else: + self._processes[p.pid].timeout.start() return kill_process_ignore_exceptions(p) From d3ee08e155d56b80fd82b988248f0fba81af6e69 Mon Sep 17 00:00:00 2001 From: Jin Xin Ng Date: Thu, 18 Aug 2022 12:47:52 -0700 Subject: [PATCH 5/5] Remove custom timeout --- .../distributed/local/local_worker_manager.py | 4 +- compiler_opt/distributed/worker.py | 13 +-- compiler_opt/rl/compilation_runner.py | 82 ++++++------------- compiler_opt/rl/compilation_runner_test.py | 20 ++--- compiler_opt/rl/inlining/inlining_runner.py | 2 + compiler_opt/rl/local_data_collector_test.py | 6 +- compiler_opt/rl/regalloc/regalloc_runner.py | 1 + 7 files changed, 43 insertions(+), 85 deletions(-) diff --git a/compiler_opt/distributed/local/local_worker_manager.py b/compiler_opt/distributed/local/local_worker_manager.py index 6431140f..425d83b3 100644 --- a/compiler_opt/distributed/local/local_worker_manager.py +++ b/compiler_opt/distributed/local/local_worker_manager.py @@ -31,8 +31,8 @@ import dataclasses import functools import multiprocessing -import threading import psutil +import threading from absl import logging # pylint: disable=unused-import @@ -40,7 +40,7 @@ from contextlib import AbstractContextManager from multiprocessing import connection -from typing import Any, Callable, Dict, Optional, List +from typing import Any, Callable, Dict, List, Optional @dataclasses.dataclass(frozen=True) diff --git a/compiler_opt/distributed/worker.py b/compiler_opt/distributed/worker.py index ddad65d9..ff040d8e 100644 --- a/compiler_opt/distributed/worker.py +++ b/compiler_opt/distributed/worker.py @@ -14,7 +14,7 @@ # limitations under the License. """Common abstraction for a worker contract.""" -from typing import Iterable, Optional, TypeVar, Protocol, runtime_checkable +from typing import Iterable, Optional, Protocol, TypeVar class Worker(Protocol): @@ -25,17 +25,6 @@ def is_priority_method(cls, method_name: str) -> bool: return False -@runtime_checkable -class ContextAwareWorker(Worker, Protocol): - """ContextAwareWorkers use set_context to modify internal state, this allows - it to behave differently when run remotely vs locally. The user of a - ContextAwareWorker can check for this with isinstance(obj, ContextAwareWorker) - """ - - def set_context(self, local: bool) -> None: - return - - T = TypeVar('T') diff --git a/compiler_opt/rl/compilation_runner.py b/compiler_opt/rl/compilation_runner.py index 4b8f1a40..91060f80 100644 --- a/compiler_opt/rl/compilation_runner.py +++ b/compiler_opt/rl/compilation_runner.py @@ -21,7 +21,6 @@ import signal import subprocess import threading -import time from typing import Dict, List, Optional, Tuple from absl import flags @@ -104,22 +103,14 @@ class WorkerCancellationManager: managing resources. """ - @dataclasses.dataclass - class ProcData: - process: 'subprocess.Popen[bytes]' - timeout: threading.Timer - time_left: float - start_time: float - - def __init__(self, timeout: float = _COMPILATION_TIMEOUT.value): + def __init__(self): # the queue is filled only by workers, and drained only by the single # consumer. we use _done to manage access to the queue. We can then assume # empty() is accurate and get() never blocks. - self._processes: Dict[int, WorkerCancellationManager.ProcData] = {} + self._processes = set() self._done = False self._paused = False self._lock = threading.Lock() - self._timeout = timeout def enable(self): with self._lock: @@ -129,16 +120,7 @@ def register_process(self, p: 'subprocess.Popen[bytes]'): """Register a process for potential cancellation.""" with self._lock: if not self._done: - self._processes[p.pid] = self.ProcData( - process=p, - timeout=threading.Timer(self._timeout, - kill_process_ignore_exceptions, (p,)), - time_left=self._timeout, - start_time=time.time()) - if self._paused: - os.kill(p.pid, signal.SIGSTOP) - else: - self._processes[p.pid].timeout.start() + self._processes.add(p) return kill_process_ignore_exceptions(p) @@ -146,8 +128,8 @@ def kill_all_processes(self): """Cancel any pending work.""" with self._lock: self._done = True - for pdata in self._processes.values(): - kill_process_ignore_exceptions(pdata.process) + for p in self._processes: + kill_process_ignore_exceptions(p) def pause_all_processes(self): with self._lock: @@ -155,17 +137,9 @@ def pause_all_processes(self): return self._paused = True - cur_time = time.time() - for pid, pdata in self._processes.items(): - pdata.timeout.cancel() - pdata.time_left -= cur_time - pdata.start_time - if pdata.time_left > 0: - # used to send the STOP signal; does not actually kill the process - os.kill(pid, signal.SIGSTOP) - else: - # In case we cancelled right after the timeout expired, - # but before actually killing the process. - kill_process_ignore_exceptions(pdata.process) + for p in self._processes: + # used to send the STOP signal; does not actually kill the process + os.kill(p.pid, signal.SIGSTOP) def resume_all_processes(self): with self._lock: @@ -173,21 +147,13 @@ def resume_all_processes(self): return self._paused = False - cur_time = time.time() - for pid, pdata in self._processes.items(): - pdata.timeout = threading.Timer(pdata.time_left, - kill_process_ignore_exceptions, - (pdata.process,)) - pdata.timeout.start() - pdata.start_time = cur_time + for p in self._processes: # used to send the CONTINUE signal; does not actually kill the process - os.kill(pid, signal.SIGCONT) + os.kill(p.pid, signal.SIGCONT) def unregister_process(self, p: 'subprocess.Popen[bytes]'): with self._lock: - if p.pid in self._processes: - self._processes[p.pid].timeout.cancel() - del self._processes[p.pid] + self._processes.remove(p) def __del__(self): if len(self._processes) > 0: @@ -196,6 +162,7 @@ def __del__(self): def start_cancellable_process( cmdline: List[str], + timeout: float, cancellation_manager: Optional[WorkerCancellationManager], want_output: bool = False) -> Optional[bytes]: """Start a cancellable process. @@ -224,10 +191,15 @@ def start_cancellable_process( if cancellation_manager: cancellation_manager.register_process(p) - retcode = p.wait() + try: + retcode = p.wait(timeout=timeout) + except subprocess.TimeoutExpired as e: + kill_process_ignore_exceptions(p) + raise e + finally: + if cancellation_manager: + cancellation_manager.unregister_process(p) - if cancellation_manager: - cancellation_manager.unregister_process(p) if retcode != 0: raise ProcessKilledError( ) if retcode == -9 else subprocess.CalledProcessError(retcode, cmdline) @@ -307,12 +279,10 @@ def is_priority_method(cls, method_name: str) -> bool: 'cancel_all_work', 'pause_all_work', 'resume_all_work' } - def __init__( - self, - clang_path: Optional[str] = None, - launcher_path: Optional[str] = None, - moving_average_decay_rate: float = 1, - cancellation_manager: Optional[WorkerCancellationManager] = None): + def __init__(self, + clang_path: Optional[str] = None, + launcher_path: Optional[str] = None, + moving_average_decay_rate: float = 1): """Initialization of CompilationRunner class. Args: @@ -323,8 +293,8 @@ def __init__( self._clang_path = clang_path self._launcher_path = launcher_path self._moving_average_decay_rate = moving_average_decay_rate - self._cancellation_manager = ( - cancellation_manager or WorkerCancellationManager()) + self._compilation_timeout = _COMPILATION_TIMEOUT.value + self._cancellation_manager = WorkerCancellationManager() # re-allow the cancellation manager accept work. def enable(self): diff --git a/compiler_opt/rl/compilation_runner_test.py b/compiler_opt/rl/compilation_runner_test.py index b029817a..4390d51d 100644 --- a/compiler_opt/rl/compilation_runner_test.py +++ b/compiler_opt/rl/compilation_runner_test.py @@ -214,9 +214,9 @@ def test_exception_handling(self, mock_compile_fn): self.assertEqual(1, mock_compile_fn.call_count) def test_start_subprocess_output(self): - cm = compilation_runner.WorkerCancellationManager(100) + cm = compilation_runner.WorkerCancellationManager() output = compilation_runner.start_cancellable_process( - ['ls', '-l'], cancellation_manager=cm, want_output=True) + ['ls', '-l'], timeout=100, cancellation_manager=cm, want_output=True) if output: output_str = output.decode('utf-8') else: @@ -228,30 +228,30 @@ def test_timeout_kills_process(self): 'test_timeout_kills_test_file') if os.path.exists(sentinel_file): os.remove(sentinel_file) - with self.assertRaises(compilation_runner.ProcessKilledError): - cm = compilation_runner.WorkerCancellationManager(0.5) + with self.assertRaises(subprocess.TimeoutExpired): compilation_runner.start_cancellable_process( ['bash', '-c', 'sleep 1s ; touch ' + sentinel_file], - cancellation_manager=cm) + timeout=0.5, + cancellation_manager=None) time.sleep(2) self.assertFalse(os.path.exists(sentinel_file)) def test_pause_resume(self): - # This also makes sure timeouts are restored properly. - cm = compilation_runner.WorkerCancellationManager(1) + cm = compilation_runner.WorkerCancellationManager() start_time = time.time() def stop_and_start(): time.sleep(0.25) cm.pause_all_processes() - time.sleep(2) + time.sleep(1) cm.resume_all_processes() threading.Thread(target=stop_and_start).start() compilation_runner.start_cancellable_process(['sleep', '0.5'], + 30, cancellation_manager=cm) - # should be at least 2 seconds due to the pause. - self.assertGreater(time.time() - start_time, 2) + # should be at least 1 second due to the pause. + self.assertGreater(time.time() - start_time, 1) if __name__ == '__main__': diff --git a/compiler_opt/rl/inlining/inlining_runner.py b/compiler_opt/rl/inlining/inlining_runner.py index 7c0ca181..69730f0a 100644 --- a/compiler_opt/rl/inlining/inlining_runner.py +++ b/compiler_opt/rl/inlining/inlining_runner.py @@ -90,10 +90,12 @@ def compile_fn( command_line.extend( ['-mllvm', '-ml-inliner-model-under-training=' + tf_policy_path]) compilation_runner.start_cancellable_process(command_line, + self._compilation_timeout, cancellation_manager) command_line = [self._llvm_size_path, output_native_path] output_bytes = compilation_runner.start_cancellable_process( command_line, + timeout=self._compilation_timeout, cancellation_manager=cancellation_manager, want_output=True) if not output_bytes: diff --git a/compiler_opt/rl/local_data_collector_test.py b/compiler_opt/rl/local_data_collector_test.py index 88b9a050..35491fae 100644 --- a/compiler_opt/rl/local_data_collector_test.py +++ b/compiler_opt/rl/local_data_collector_test.py @@ -80,13 +80,9 @@ def mock_collect_data(module_spec, tf_policy_dir, reward_stat): class Sleeper(compilation_runner.CompilationRunner): """Test CompilationRunner that just sleeps.""" - def __init__(self): - super().__init__( - cancellation_manager=compilation_runner.WorkerCancellationManager(3600)) - def collect_data(self, module_spec, tf_policy_path, reward_stat): _ = module_spec, tf_policy_path, reward_stat - compilation_runner.start_cancellable_process(['sleep', '3600s'], + compilation_runner.start_cancellable_process(['sleep', '3600s'], 3600, self._cancellation_manager) return compilation_runner.CompilationResult( diff --git a/compiler_opt/rl/regalloc/regalloc_runner.py b/compiler_opt/rl/regalloc/regalloc_runner.py index 637a40b2..ef6cf0ce 100644 --- a/compiler_opt/rl/regalloc/regalloc_runner.py +++ b/compiler_opt/rl/regalloc/regalloc_runner.py @@ -87,6 +87,7 @@ def compile_fn( if tf_policy_path: command_line.extend(['-mllvm', '-regalloc-model=' + tf_policy_path]) compilation_runner.start_cancellable_process(command_line, + self._compilation_timeout, cancellation_manager) sequence_example = struct_pb2.Struct()