Skip to content

Commit 16eb08b

Browse files
authored
Add pause/resume/context to workers (#101)
* Add pause/resume/context to workers
  - Allows a user to pause and resume processes at will, via the OS signals SIGSTOP and SIGCONT.
  - Allows a user to bind processes to specific CPUs.
  - Allows local_worker_pool to be used outside of a context manager.
  - Switches workers to be Protocol-based, so Workers are effectively duck-typed (i.e. anything that has the required methods passes as a Worker).

  Part of #96
1 parent 45d1e2d commit 16eb08b

File tree

6 files changed

+82
-9
lines changed

6 files changed

+82
-9
lines changed

compiler_opt/distributed/local/local_worker_manager.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import dataclasses
3232
import functools
3333
import multiprocessing
34+
import psutil
3435
import threading
3536

3637
from absl import logging
@@ -39,7 +40,7 @@
3940

4041
from contextlib import AbstractContextManager
4142
from multiprocessing import connection
42-
from typing import Any, Callable, Dict, Optional
43+
from typing import Any, Callable, Dict, List, Optional
4344

4445

4546
@dataclasses.dataclass(frozen=True)
@@ -214,6 +215,18 @@ def shutdown(self):
214215
except: # pylint: disable=bare-except
215216
pass
216217

218+
def set_nice(self, val: int):
219+
"""Sets the nice-ness of the process, this modifies how the OS
220+
schedules it. Only works on Unix, since val is presumed to be an int.
221+
"""
222+
psutil.Process(self._process.pid).nice(val)
223+
224+
def set_affinity(self, val: List[int]):
225+
"""Sets the CPU affinity of the process, this modifies which cores the OS
226+
schedules it on.
227+
"""
228+
psutil.Process(self._process.pid).cpu_affinity(val)
229+
217230
def join(self):
218231
self._observer.join()
219232
self._pump.join()
@@ -247,3 +260,11 @@ def __exit__(self, *args):
247260
# now wait for the message pumps to indicate they exit.
248261
for s in self._stubs:
249262
s.join()
263+
264+
def __del__(self):
265+
self.__exit__()
266+
267+
@property
268+
def stubs(self):
269+
# Return a shallow copy so that callers cannot modify the internal list
270+
return list(self._stubs)

compiler_opt/distributed/local/local_worker_manager_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""Test for local worker manager."""
1616

1717
import concurrent.futures
18-
import multiprocessing
1918
import time
2019

2120
from absl.testing import absltest

compiler_opt/distributed/worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414
# limitations under the License.
1515
"""Common abstraction for a worker contract."""
1616

17-
from typing import Iterable, Optional, TypeVar, Protocol
17+
from typing import Iterable, Optional, Protocol, TypeVar
1818

1919

20-
class Worker:
20+
class Worker(Protocol):
2121

2222
@classmethod
2323
def is_priority_method(cls, method_name: str) -> bool:

compiler_opt/rl/compilation_runner.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import dataclasses
1919
import json
2020
import os
21+
import signal
2122
import subprocess
2223
import threading
2324
from typing import Dict, List, Optional, Tuple
@@ -108,6 +109,7 @@ def __init__(self):
108109
# empty() is accurate and get() never blocks.
109110
self._processes = set()
110111
self._done = False
112+
self._paused = False
111113
self._lock = threading.Lock()
112114

113115
def enable(self):
@@ -129,10 +131,33 @@ def kill_all_processes(self):
129131
for p in self._processes:
130132
kill_process_ignore_exceptions(p)
131133

134+
def pause_all_processes(self):
135+
with self._lock:
136+
if self._paused:
137+
return
138+
self._paused = True
139+
140+
for p in self._processes:
141+
# os.kill() here only delivers SIGSTOP, which suspends the process; it does not kill it
142+
os.kill(p.pid, signal.SIGSTOP)
143+
144+
def resume_all_processes(self):
145+
with self._lock:
146+
if not self._paused:
147+
return
148+
self._paused = False
149+
150+
for p in self._processes:
151+
# os.kill() here only delivers SIGCONT, which resumes the process; it does not kill it
152+
os.kill(p.pid, signal.SIGCONT)
153+
132154
def unregister_process(self, p: 'subprocess.Popen[bytes]'):
133155
with self._lock:
134-
if not self._done:
135-
self._processes.remove(p)
156+
self._processes.remove(p)
157+
158+
def __del__(self):
159+
if len(self._processes) > 0:
160+
raise RuntimeError('Cancellation manager deleted while containing items.')
136161

137162

138163
def start_cancellable_process(
@@ -174,6 +199,7 @@ def start_cancellable_process(
174199
finally:
175200
if cancellation_manager:
176201
cancellation_manager.unregister_process(p)
202+
177203
if retcode != 0:
178204
raise ProcessKilledError(
179205
) if retcode == -9 else subprocess.CalledProcessError(retcode, cmdline)
@@ -249,7 +275,9 @@ class CompilationRunner(Worker):
249275

250276
@classmethod
251277
def is_priority_method(cls, method_name: str) -> bool:
252-
return method_name == 'cancel_all_work'
278+
return method_name in {
279+
'cancel_all_work', 'pause_all_work', 'resume_all_work'
280+
}
253281

254282
def __init__(self,
255283
clang_path: Optional[str] = None,
@@ -275,6 +303,12 @@ def enable(self):
275303
def cancel_all_work(self):
276304
self._cancellation_manager.kill_all_processes()
277305

306+
def pause_all_work(self):
307+
self._cancellation_manager.pause_all_processes()
308+
309+
def resume_all_work(self):
310+
self._cancellation_manager.resume_all_processes()
311+
278312
def collect_data(
279313
self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
280314
reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult:

compiler_opt/rl/compilation_runner_test.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
import string
1919
import subprocess
20+
import threading
2021
import time
2122
from unittest import mock
2223

@@ -213,9 +214,9 @@ def test_exception_handling(self, mock_compile_fn):
213214
self.assertEqual(1, mock_compile_fn.call_count)
214215

215216
def test_start_subprocess_output(self):
216-
ct = compilation_runner.WorkerCancellationManager()
217+
cm = compilation_runner.WorkerCancellationManager()
217218
output = compilation_runner.start_cancellable_process(
218-
['ls', '-l'], timeout=100, cancellation_manager=ct, want_output=True)
219+
['ls', '-l'], timeout=100, cancellation_manager=cm, want_output=True)
219220
if output:
220221
output_str = output.decode('utf-8')
221222
else:
@@ -235,6 +236,23 @@ def test_timeout_kills_process(self):
235236
time.sleep(2)
236237
self.assertFalse(os.path.exists(sentinel_file))
237238

239+
def test_pause_resume(self):
240+
cm = compilation_runner.WorkerCancellationManager()
241+
start_time = time.time()
242+
243+
def stop_and_start():
244+
time.sleep(0.25)
245+
cm.pause_all_processes()
246+
time.sleep(1)
247+
cm.resume_all_processes()
248+
249+
threading.Thread(target=stop_and_start).start()
250+
compilation_runner.start_cancellable_process(['sleep', '0.5'],
251+
30,
252+
cancellation_manager=cm)
253+
# Should take more than 1 second in total, because the process is paused for 1 second.
254+
self.assertGreater(time.time() - start_time, 1)
255+
238256

239257
if __name__ == '__main__':
240258
tf.test.main()

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ oauthlib==3.1.1
2727
opt-einsum==3.3.0
2828
pillow==8.3.1
2929
protobuf==3.17.3
30+
psutil==5.9.0
3031
pyasn1==0.4.8
3132
pyasn1_modules==0.2.8
3233
pyglet==1.5.0

0 commit comments

Comments
 (0)