Skip to content

Commit d3bb465

Browse files
committed
Add pause/resume/context to workers
- Allows a user to start/stop processes at will, via OS signals SIGSTOP and SIGCONT. - Allows a user to bind processes to specific CPUs. - Allows local_worker_pool to be used outside of a context manager - Switch workers to be Protocol based, so Workers are effectively duck-typed (i.e. anything that has the required methods passes as a Worker) Part of google#96
1 parent 45d1e2d commit d3bb465

File tree

3 files changed

+54
-3
lines changed

3 files changed

+54
-3
lines changed

compiler_opt/distributed/local/local_worker_manager.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,17 @@
3232
import functools
3333
import multiprocessing
3434
import threading
35+
import os
36+
import psutil
37+
import signal
3538

3639
from absl import logging
3740
# pylint: disable=unused-import
3841
from compiler_opt.distributed.worker import Worker
3942

4043
from contextlib import AbstractContextManager
4144
from multiprocessing import connection
42-
from typing import Any, Callable, Dict, Optional
45+
from typing import Any, Callable, Dict, Optional, List
4346

4447

4548
@dataclasses.dataclass(frozen=True)
@@ -131,6 +134,7 @@ def __init__(self):
131134
# when we stop.
132135
self._lock = threading.Lock()
133136
self._map: Dict[int, concurrent.futures.Future] = {}
137+
self.is_paused = False
134138

135139
# thread draining the pipe
136140
self._pump = threading.Thread(target=self._msg_pump)
@@ -210,10 +214,37 @@ def shutdown(self):
210214
try:
211215
# Killing the process triggers observer exit, which triggers msg_pump
212216
# exit
217+
self.resume()
213218
self._process.kill()
214219
except: # pylint: disable=bare-except
215220
pass
216221

222+
def pause(self):
223+
if self.is_paused:
224+
return
225+
self.is_paused = True
226+
# used to send the STOP signal; does not actually kill the process
227+
os.kill(self._process.pid, signal.SIGSTOP)
228+
229+
def resume(self):
230+
if not self.is_paused:
231+
return
232+
self.is_paused = False
233+
# used to send the CONTINUE signal; does not actually kill the process
234+
os.kill(self._process.pid, signal.SIGCONT)
235+
236+
def set_nice(self, val: int):
237+
"""Sets the nice-ness of the process, this modifies how the OS
238+
schedules it. Only works on Unix, since val is presumed to be an int.
239+
"""
240+
psutil.Process(self._process.pid).nice(val)
241+
242+
def set_affinity(self, val: List[int]):
243+
"""Sets the CPU affinity of the process, this modifies which cores the OS
244+
schedules it on.
245+
"""
246+
psutil.Process(self._process.pid).cpu_affinity(val)
247+
217248
def join(self):
218249
self._observer.join()
219250
self._pump.join()
@@ -247,3 +278,11 @@ def __exit__(self, *args):
247278
# now wait for the message pumps to indicate they exit.
248279
for s in self._stubs:
249280
s.join()
281+
282+
def __del__(self):
283+
self.__exit__()
284+
285+
@property
286+
def stubs(self):
287+
# Return a shallow copy, to avoid something messing the internal list up
288+
return list(self._stubs)

compiler_opt/distributed/worker.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,28 @@
1414
# limitations under the License.
1515
"""Common abstraction for a worker contract."""
1616

17-
from typing import Iterable, Optional, TypeVar, Protocol
17+
from typing import Iterable, Optional, TypeVar, Protocol, runtime_checkable
1818

1919

20-
class Worker:
20+
class Worker(Protocol):
2121

2222
@classmethod
2323
def is_priority_method(cls, method_name: str) -> bool:
2424
_ = method_name
2525
return False
2626

2727

28+
@runtime_checkable
29+
class ContextAwareWorker(Worker, Protocol):
30+
"""ContextAwareWorkers use set_context to modify internal state, this allows
31+
it to behave differently when run remotely vs locally. The user of a
32+
ContextAwareWorker can check for this with isinstance(obj, ContextAwareWorker)
33+
"""
34+
35+
def set_context(self, local: bool) -> None:
36+
return
37+
38+
2839
T = TypeVar('T')
2940

3041

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ oauthlib==3.1.1
2727
opt-einsum==3.3.0
2828
pillow==8.3.1
2929
protobuf==3.17.3
30+
psutil==5.9.0
3031
pyasn1==0.4.8
3132
pyasn1_modules==0.2.8
3233
pyglet==1.5.0

0 commit comments

Comments
 (0)