Skip to content

Add pause/resume/context to workers #101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion compiler_opt/distributed/local/local_worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import dataclasses
import functools
import multiprocessing
import psutil
import threading

from absl import logging
Expand All @@ -39,7 +40,7 @@

from contextlib import AbstractContextManager
from multiprocessing import connection
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -214,6 +215,18 @@ def shutdown(self):
except: # pylint: disable=bare-except
pass

def set_nice(self, val: int):
"""Sets the nice-ness of the process, this modifies how the OS
schedules it. Only works on Unix, since val is presumed to be an int.
"""
psutil.Process(self._process.pid).nice(val)

def set_affinity(self, val: List[int]):
"""Sets the CPU affinity of the process, this modifies which cores the OS
schedules it on.
"""
psutil.Process(self._process.pid).cpu_affinity(val)

def join(self):
self._observer.join()
self._pump.join()
Expand Down Expand Up @@ -247,3 +260,11 @@ def __exit__(self, *args):
# now wait for the message pumps to indicate they exit.
for s in self._stubs:
s.join()

def __del__(self):
self.__exit__()

@property
def stubs(self):
# Return a shallow copy, to avoid something messing the internal list up
return list(self._stubs)
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Test for local worker manager."""

import concurrent.futures
import multiprocessing
import time

from absl.testing import absltest
Expand Down
4 changes: 2 additions & 2 deletions compiler_opt/distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# limitations under the License.
"""Common abstraction for a worker contract."""

from typing import Iterable, Optional, TypeVar, Protocol
from typing import Iterable, Optional, Protocol, TypeVar


class Worker:
class Worker(Protocol):

@classmethod
def is_priority_method(cls, method_name: str) -> bool:
Expand Down
40 changes: 37 additions & 3 deletions compiler_opt/rl/compilation_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import dataclasses
import json
import os
import signal
import subprocess
import threading
from typing import Dict, List, Optional, Tuple
Expand Down Expand Up @@ -108,6 +109,7 @@ def __init__(self):
# empty() is accurate and get() never blocks.
self._processes = set()
self._done = False
self._paused = False
self._lock = threading.Lock()

def enable(self):
Expand All @@ -129,10 +131,33 @@ def kill_all_processes(self):
for p in self._processes:
kill_process_ignore_exceptions(p)

def pause_all_processes(self):
with self._lock:
if self._paused:
return
self._paused = True

for p in self._processes:
# used to send the STOP signal; does not actually kill the process
os.kill(p.pid, signal.SIGSTOP)

def resume_all_processes(self):
with self._lock:
if not self._paused:
return
self._paused = False

for p in self._processes:
# used to send the CONTINUE signal; does not actually kill the process
os.kill(p.pid, signal.SIGCONT)

def unregister_process(self, p: 'subprocess.Popen[bytes]'):
with self._lock:
if not self._done:
self._processes.remove(p)
self._processes.remove(p)

def __del__(self):
if len(self._processes) > 0:
raise RuntimeError('Cancellation manager deleted while containing items.')


def start_cancellable_process(
Expand Down Expand Up @@ -174,6 +199,7 @@ def start_cancellable_process(
finally:
if cancellation_manager:
cancellation_manager.unregister_process(p)

if retcode != 0:
raise ProcessKilledError(
) if retcode == -9 else subprocess.CalledProcessError(retcode, cmdline)
Expand Down Expand Up @@ -249,7 +275,9 @@ class CompilationRunner(Worker):

@classmethod
def is_priority_method(cls, method_name: str) -> bool:
return method_name == 'cancel_all_work'
return method_name in {
'cancel_all_work', 'pause_all_work', 'resume_all_work'
}

def __init__(self,
clang_path: Optional[str] = None,
Expand All @@ -275,6 +303,12 @@ def enable(self):
def cancel_all_work(self):
self._cancellation_manager.kill_all_processes()

def pause_all_work(self):
self._cancellation_manager.pause_all_processes()

def resume_all_work(self):
self._cancellation_manager.resume_all_processes()

def collect_data(
self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult:
Expand Down
22 changes: 20 additions & 2 deletions compiler_opt/rl/compilation_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import string
import subprocess
import threading
import time
from unittest import mock

Expand Down Expand Up @@ -213,9 +214,9 @@ def test_exception_handling(self, mock_compile_fn):
self.assertEqual(1, mock_compile_fn.call_count)

def test_start_subprocess_output(self):
ct = compilation_runner.WorkerCancellationManager()
cm = compilation_runner.WorkerCancellationManager()
output = compilation_runner.start_cancellable_process(
['ls', '-l'], timeout=100, cancellation_manager=ct, want_output=True)
['ls', '-l'], timeout=100, cancellation_manager=cm, want_output=True)
if output:
output_str = output.decode('utf-8')
else:
Expand All @@ -235,6 +236,23 @@ def test_timeout_kills_process(self):
time.sleep(2)
self.assertFalse(os.path.exists(sentinel_file))

def test_pause_resume(self):
cm = compilation_runner.WorkerCancellationManager()
start_time = time.time()

def stop_and_start():
time.sleep(0.25)
cm.pause_all_processes()
time.sleep(1)
cm.resume_all_processes()

threading.Thread(target=stop_and_start).start()
compilation_runner.start_cancellable_process(['sleep', '0.5'],
30,
cancellation_manager=cm)
# should be at least 1 second due to the pause.
self.assertGreater(time.time() - start_time, 1)


if __name__ == '__main__':
tf.test.main()
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ oauthlib==3.1.1
opt-einsum==3.3.0
pillow==8.3.1
protobuf==3.17.3
psutil==5.9.0
pyasn1==0.4.8
pyasn1_modules==0.2.8
pyglet==1.5.0
Expand Down