Skip to content

Commit 32a5c1c

Browse files
authored
Use LocalWorkerManager (#48)
This change switches to using LocalWorkerManager. The main change to the compilation runner code is that cancellation is now handled 'all-out' - we just cancel all jobs on the worker. Issue #31
1 parent ea5f7d6 commit 32a5c1c

File tree

5 files changed

+251
-264
lines changed

5 files changed

+251
-264
lines changed

compiler_opt/distributed/worker.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,46 @@
1414
# limitations under the License.
1515
"""Common abstraction for a worker contract."""
1616

17+
import abc
18+
from typing import Generic, Iterable, TypeVar
19+
1720

1821
class Worker:
  """Base contract that every worker implements."""

  @classmethod
  def is_priority_method(cls, method_name: str) -> bool:
    """Whether `method_name` should bypass normal job queueing.

    The base implementation prioritizes nothing; subclasses override this
    to mark methods (e.g. cancellation) that must run ahead of queued work.
    """
    del method_name  # Unused in the default implementation.
    return False
27+
T = TypeVar('T')


class WorkerFuture(Generic[T], metaclass=abc.ABCMeta):
  """Minimal future abstraction.

  Dask's Futures only support result() and done(); this captures that lowest
  common denominator so different worker managers stay interchangeable.
  """

  @abc.abstractmethod
  def result(self) -> T:
    """Block until the work completes, then return its value or re-raise."""
    raise NotImplementedError()

  @abc.abstractmethod
  def done(self) -> bool:
    """Return True once the work has finished, successfully or not."""
    raise NotImplementedError()


def wait_for(futures: Iterable[WorkerFuture]):
  """Block until all `futures` complete, discarding results and errors.

  Dask futures don't support more than result() and done().
  """
  for f in futures:
    try:
      _ = f.result()
    # Catch Exception rather than bare `except:` so KeyboardInterrupt and
    # SystemExit still propagate; consistent with get_exception() below.
    except Exception:  # pylint: disable=broad-except
      pass


def get_exception(worker_future: WorkerFuture):
  """Return the exception a completed future raised, or None on success."""
  assert worker_future.done()
  try:
    _ = worker_future.result()
    return None
  except Exception as e:  # pylint: disable=broad-except
    return e

compiler_opt/rl/compilation_runner.py

Lines changed: 40 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,17 @@
1414
# limitations under the License.
1515
"""Module for running compilation and collect training data."""
1616

17-
import concurrent
17+
import abc
1818
import dataclasses
1919
import json
20-
import multiprocessing
2120
import subprocess
2221
import threading
2322
from typing import Dict, List, Optional, Tuple
2423

2524
from absl import flags
26-
import tensorflow as tf
27-
25+
from compiler_opt.distributed.worker import Worker, WorkerFuture
2826
from compiler_opt.rl import constant
27+
import tensorflow as tf
2928

3029
_COMPILATION_TIMEOUT = flags.DEFINE_integer(
3130
'compilation_timeout', 60,
@@ -122,18 +121,6 @@ def __init__(self):
122121
Exception.__init__(self)
123122

124123

125-
class ProcessCancellationToken:
126-
127-
def __init__(self):
128-
self._event = multiprocessing.Manager().Event()
129-
130-
def signal(self):
131-
self._event.set()
132-
133-
def wait(self):
134-
self._event.wait()
135-
136-
137124
def kill_process_ignore_exceptions(p: 'subprocess.Popen[bytes]'):
138125
# kill the process and ignore exceptions. Exceptions would be thrown if the
139126
# process has already been killed/finished (which is inherently in a race
@@ -160,6 +147,10 @@ def __init__(self):
160147
self._done = False
161148
self._lock = threading.Lock()
162149

150+
def enable(self):
151+
with self._lock:
152+
self._done = False
153+
163154
def register_process(self, p: 'subprocess.Popen[bytes]'):
164155
"""Register a process for potential cancellation."""
165156
with self._lock:
@@ -168,7 +159,7 @@ def register_process(self, p: 'subprocess.Popen[bytes]'):
168159
return
169160
kill_process_ignore_exceptions(p)
170161

171-
def signal(self):
162+
def kill_all_processes(self):
172163
"""Cancel any pending work."""
173164
with self._lock:
174165
self._done = True
@@ -265,21 +256,31 @@ def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
265256
assert not hasattr(self, 'sequence_examples')
266257

267258

268-
class CompilationRunner:
269-
"""Base class for collecting compilation data."""
259+
class CompilationRunnerStub(metaclass=abc.ABCMeta):
  """The interface of a stub to CompilationRunner, for type checkers."""

  @abc.abstractmethod
  def collect_data(
      self, file_paths: Tuple[str, ...], tf_policy_path: str,
      reward_stat: Optional[Dict[str, RewardStat]]
  ) -> WorkerFuture[CompilationResult]:
    # Remote counterpart of CompilationRunner.collect_data; returns a future.
    raise NotImplementedError()

  @abc.abstractmethod
  def cancel_all_work(self) -> WorkerFuture:
    # Remote counterpart of CompilationRunner.cancel_all_work.
    raise NotImplementedError()

  @abc.abstractmethod
  def enable(self) -> WorkerFuture:
    # Remote counterpart of CompilationRunner.enable.
    raise NotImplementedError()
277+
278+
class CompilationRunner(Worker):
  """Base class for collecting compilation data."""

  @classmethod
  def is_priority_method(cls, method_name: str) -> bool:
    # Cancellation must jump ahead of queued compilation requests.
    return method_name == 'cancel_all_work'
283284

284285
def __init__(self,
285286
clang_path: Optional[str] = None,
@@ -302,40 +303,18 @@ def __init__(self,
302303
self._additional_flags = additional_flags
303304
self._delete_flags = delete_flags
304305
self._compilation_timeout = _COMPILATION_TIMEOUT.value
306+
self._cancellation_manager = WorkerCancellationManager()
305307

306-
def _get_cancellation_manager(
307-
self, cancellation_token: Optional[ProcessCancellationToken]
308-
) -> Optional[WorkerCancellationManager]:
309-
"""Convert the ProcessCancellationToken into a WorkerCancellationManager.
310-
311-
The conversion also registers the ProcessCancellationToken wait() on a
312-
thread which will call the WorkerCancellationManager upon completion.
313-
Since the token is always signaled, the thread always completes its work.
314-
315-
Args:
316-
cancellation_token: the ProcessCancellationToken to convert.
317-
318-
Returns:
319-
a WorkerCancellationManager, if a ProcessCancellationToken was given.
320-
"""
321-
if not cancellation_token:
322-
return None
323-
ret = WorkerCancellationManager()
324-
325-
def signaler():
326-
cancellation_token.wait()
327-
ret.signal()
308+
# re-allow the cancellation manager accept work.
309+
def enable(self):
310+
self._cancellation_manager.enable()
328311

329-
CompilationRunner._get_pool().submit(signaler)
330-
return ret
312+
def cancel_all_work(self):
313+
self._cancellation_manager.kill_all_processes()
331314

332315
def collect_data(
333-
self,
334-
file_paths: Tuple[str, ...],
335-
tf_policy_path: str,
336-
reward_stat: Optional[Dict[str, RewardStat]],
337-
cancellation_token: Optional[ProcessCancellationToken] = None
338-
) -> CompilationResult:
316+
self, file_paths: Tuple[str, ...], tf_policy_path: str,
317+
reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult:
339318
"""Collect data for the given IR file and policy.
340319
341320
Args:
@@ -355,14 +334,12 @@ def collect_data(
355334
compilation_runner.ProcessKilledException is passed through.
356335
ValueError if example under default policy and ml policy does not match.
357336
"""
358-
cancellation_manager = self._get_cancellation_manager(cancellation_token)
359-
360337
if reward_stat is None:
361338
default_result = self._compile_fn(
362339
file_paths,
363340
tf_policy_path='',
364341
reward_only=bool(tf_policy_path),
365-
cancellation_manager=cancellation_manager)
342+
cancellation_manager=self._cancellation_manager)
366343
reward_stat = {
367344
k: RewardStat(v[1], v[1]) for (k, v) in default_result.items()
368345
}
@@ -372,7 +349,7 @@ def collect_data(
372349
file_paths,
373350
tf_policy_path,
374351
reward_only=False,
375-
cancellation_manager=cancellation_manager)
352+
cancellation_manager=self._cancellation_manager)
376353
else:
377354
policy_result = default_result
378355

0 commit comments

Comments
 (0)