Skip to content

Commit b623fd0

Browse files
committed
Add pause/resume/context to workers
- Allows a user to start/stop (pause/resume) worker processes at will, via the OS signals SIGSTOP and SIGCONT.
- Allows a user to bind processes to specific CPUs.
- Allows local_worker_pool to be used outside of a context manager.
- Switches workers to be Protocol-based, so Workers are effectively duck-typed (i.e. anything that has the required methods passes as a Worker).

Part of google#96
1 parent 7e4b19f commit b623fd0

File tree

3 files changed

+54
-3
lines changed

3 files changed

+54
-3
lines changed

compiler_opt/distributed/local/local_worker_manager.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,17 @@
3232
import functools
3333
import multiprocessing
3434
import threading
35+
import os
36+
import psutil
37+
import signal
3538

3639
from absl import logging
3740
# pylint: disable=unused-import
3841
from compiler_opt.distributed.worker import Worker
3942

4043
from contextlib import AbstractContextManager
4144
from multiprocessing import connection
42-
from typing import Any, Callable, Dict, Optional
45+
from typing import Any, Callable, Dict, Optional, List
4346

4447

4548
@dataclasses.dataclass(frozen=True)
@@ -131,6 +134,7 @@ def __init__(self):
131134
# when we stop.
132135
self._lock = threading.Lock()
133136
self._map: Dict[int, concurrent.futures.Future] = {}
137+
self.is_paused = False
134138

135139
# thread draining the pipe
136140
self._pump = threading.Thread(target=self._msg_pump)
@@ -205,10 +209,37 @@ def shutdown(self):
205209
try:
206210
# Killing the process triggers observer exit, which triggers msg_pump
207211
# exit
212+
self.resume()
208213
self._process.kill()
209214
except: # pylint: disable=bare-except
210215
pass
211216

217+
def pause(self):
  """Suspends the worker process by sending it SIGSTOP.

  Does nothing when the stub is already marked as paused. The signal only
  stops the process; it does not terminate it.

  Raises:
    OSError: propagated from os.kill if the signal cannot be delivered.
  """
  if self.is_paused:
    return
  # Deliver the signal first and record the new state only afterwards, so a
  # failing os.kill cannot leave is_paused claiming the process is stopped
  # (which would turn every later pause() into a silent no-op).
  os.kill(self._process.pid, signal.SIGSTOP)
  self.is_paused = True
223+
224+
def resume(self):
  """Lets a paused worker process continue by sending it SIGCONT.

  Does nothing when the stub is not marked as paused.

  Raises:
    OSError: propagated from os.kill if the signal cannot be delivered.
  """
  if not self.is_paused:
    return
  # Deliver the signal first and clear the flag only afterwards, so a failing
  # os.kill cannot leave is_paused claiming the process is running while it
  # is in fact still stopped.
  os.kill(self._process.pid, signal.SIGCONT)
  self.is_paused = False
230+
231+
def set_nice(self, val: int):
  """Changes the niceness of the worker process.

  The niceness influences how the OS schedules the process.

  Args:
    val: the niceness value, forwarded unchanged to psutil's Process.nice.
      It is presumed to be an int, which is why this is only expected to
      work on Unix.
  """
  worker_process = psutil.Process(self._process.pid)
  worker_process.nice(val)
236+
237+
def set_affinity(self, val: List[int]):
  """Changes the CPU affinity of the worker process.

  The affinity determines which cores the OS may schedule the process on.

  Args:
    val: the CPU list, forwarded unchanged to psutil's Process.cpu_affinity.
  """
  worker_process = psutil.Process(self._process.pid)
  worker_process.cpu_affinity(val)
242+
212243
def join(self):
213244
self._observer.join()
214245
self._pump.join()
@@ -242,3 +273,11 @@ def __exit__(self, *args):
242273
# now wait for the message pumps to indicate they exit.
243274
for s in self._stubs:
244275
s.join()
276+
277+
def __del__(self):
  """Runs the same cleanup as __exit__ when the object is finalized.

  This is what lets the object be used without a `with` statement.
  """
  # Finalizers also run for objects whose __init__ raised part-way through.
  # __exit__ iterates self._stubs, so skip the cleanup when that attribute
  # was never set instead of raising AttributeError from the finalizer.
  if hasattr(self, '_stubs'):
    self.__exit__()
279+
280+
@property
def stubs(self):
  """A new list holding the same elements as the internal stub list.

  Callers get a shallow copy, so adding to or removing from the returned
  list leaves the internal list untouched.
  """
  return [*self._stubs]

compiler_opt/distributed/worker.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,28 @@
1515
"""Common abstraction for a worker contract."""
1616

1717
import abc
18-
from typing import Generic, Iterable, Optional, TypeVar
18+
from typing import Generic, Iterable, Optional, TypeVar, Protocol, runtime_checkable
1919

2020

21-
class Worker:
21+
class Worker(Protocol):
  """Structural (Protocol-based) contract for a worker."""

  @classmethod
  def is_priority_method(cls, method_name: str) -> bool:
    """Reports whether `method_name` names a priority method.

    This default implementation answers False for every name.

    Args:
      method_name: the name of the method being asked about.

    Returns:
      Always False.
    """
    del method_name  # Unused by the default implementation.
    return False
2727

2828

29+
@runtime_checkable
class ContextAwareWorker(Worker, Protocol):
  """A Worker that can be told about the context it runs in.

  set_context lets the worker adjust its internal state so that it can behave
  differently when run remotely vs. locally. Because the protocol is runtime
  checkable, users can detect such workers with
  isinstance(obj, ContextAwareWorker).
  """

  def set_context(self, local: bool) -> None:
    """Informs the worker of its context; this default does nothing.

    Args:
      local: the context flag; judging by the class description, True is
        presumably the local case - confirm against callers.
    """
    return None
38+
39+
2940
T = TypeVar('T')
3041

3142

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ oauthlib==3.1.1
2727
opt-einsum==3.3.0
2828
pillow==8.3.1
2929
protobuf==3.17.3
30+
psutil==5.9.0
3031
pyasn1==0.4.8
3132
pyasn1_modules==0.2.8
3233
pyglet==1.5.0

0 commit comments

Comments
 (0)