Skip to content

Commit 573bbf1

Browse files
committed
Move logic to runners
1 parent f505374 commit 573bbf1

File tree

7 files changed

+114
-99
lines changed

7 files changed

+114
-99
lines changed

compiler_opt/distributed/local/local_worker_manager.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,7 @@
3232
import functools
3333
import multiprocessing
3434
import threading
35-
import os
3635
import psutil
37-
import signal
3836

3937
from absl import logging
4038
# pylint: disable=unused-import
@@ -134,7 +132,6 @@ def __init__(self):
134132
# when we stop.
135133
self._lock = threading.Lock()
136134
self._map: Dict[int, concurrent.futures.Future] = {}
137-
self.is_paused = False
138135

139136
# thread draining the pipe
140137
self._pump = threading.Thread(target=self._msg_pump)
@@ -209,25 +206,10 @@ def shutdown(self):
209206
try:
210207
# Killing the process triggers observer exit, which triggers msg_pump
211208
# exit
212-
self.resume()
213209
self._process.kill()
214210
except: # pylint: disable=bare-except
215211
pass
216212

217-
def pause(self):
218-
if self.is_paused:
219-
return
220-
self.is_paused = True
221-
# used to send the STOP signal; does not actually kill the process
222-
os.kill(self._process.pid, signal.SIGSTOP)
223-
224-
def resume(self):
225-
if not self.is_paused:
226-
return
227-
self.is_paused = False
228-
# used to send the CONTINUE signal; does not actually kill the process
229-
os.kill(self._process.pid, signal.SIGCONT)
230-
231213
def set_nice(self, val: int):
232214
"""Sets the nice-ness of the process; this modifies how the OS
233215
schedules it. Only works on Unix, since val is presumed to be an int.

compiler_opt/distributed/local/local_worker_manager_test.py

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -58,25 +58,6 @@ def method(self):
5858
time.sleep(3600)
5959

6060

61-
class JobCounter(Worker):
62-
"""Test worker."""
63-
64-
def __init__(self):
65-
self.times = []
66-
67-
@classmethod
68-
def is_priority_method(cls, method_name: str) -> bool:
69-
return method_name == 'get_times'
70-
71-
def start(self):
72-
while True:
73-
self.times.append(time.time())
74-
time.sleep(0.05)
75-
76-
def get_times(self):
77-
return self.times
78-
79-
8061
class LocalWorkerManagerTest(absltest.TestCase):
8162

8263
def test_pool(self):
@@ -118,34 +99,6 @@ def test_worker_crash_while_waiting(self):
11899
with self.assertRaises(concurrent.futures.CancelledError):
119100
_ = f.result()
120101

121-
def test_pause_resume(self):
122-
123-
with local_worker_manager.LocalWorkerPool(JobCounter, 1) as pool:
124-
p = pool[0]
125-
126-
# Fill the q for 1 second
127-
p.start()
128-
time.sleep(1)
129-
130-
# Then pause the process for 1 second
131-
p.pause()
132-
time.sleep(1)
133-
134-
# Then resume the process and wait 1 more second
135-
p.resume()
136-
time.sleep(1)
137-
138-
times = p.get_times().result()
139-
140-
# If pause/resume worked, there should be a gap of at least 0.5 seconds.
141-
# Otherwise, this will throw an exception.
142-
self.assertNotEmpty(times)
143-
last_time = times[0]
144-
for cur_time in times:
145-
if cur_time - last_time > 0.5:
146-
return
147-
raise ValueError('Failed to find a 2 second gap in times.')
148-
149102

150103
if __name__ == '__main__':
151104
multiprocessing.handle_test_main(absltest.main)

compiler_opt/rl/compilation_runner.py

Lines changed: 86 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
import dataclasses
1919
import json
2020
import os
21+
import signal
2122
import subprocess
2223
import threading
23-
from typing import Dict, List, Optional, Tuple
24+
import time
25+
from typing import Dict, List, Optional, Tuple, Union
2426

2527
from absl import flags
2628
from compiler_opt.distributed.worker import Worker, WorkerFuture
@@ -102,13 +104,22 @@ class WorkerCancellationManager:
102104
managing resources.
103105
"""
104106

105-
def __init__(self):
107+
@dataclasses.dataclass
108+
class ProcData:
109+
process: 'subprocess.Popen[bytes]'
110+
timeout: threading.Timer
111+
time_left: float
112+
start_time: float
113+
114+
def __init__(self, timeout: float = _COMPILATION_TIMEOUT.value):
106115
# the queue is filled only by workers, and drained only by the single
107116
# consumer. we use _done to manage access to the queue. We can then assume
108117
# empty() is accurate and get() never blocks.
109-
self._processes = set()
118+
self._processes: Dict[int, WorkerCancellationManager.ProcData] = {}
110119
self._done = False
120+
self._paused = False
111121
self._lock = threading.Lock()
122+
self._timeout = timeout
112123

113124
def enable(self):
114125
with self._lock:
@@ -118,26 +129,70 @@ def register_process(self, p: 'subprocess.Popen[bytes]'):
118129
"""Register a process for potential cancellation."""
119130
with self._lock:
120131
if not self._done:
121-
self._processes.add(p)
132+
self._processes[p.pid] = self.ProcData(
133+
process=p,
134+
timeout=threading.Timer(self._timeout,
135+
kill_process_ignore_exceptions, (p,)),
136+
time_left=self._timeout,
137+
start_time=time.time())
138+
self._processes[p.pid].timeout.start()
122139
return
123140
kill_process_ignore_exceptions(p)
124141

125142
def kill_all_processes(self):
126143
"""Cancel any pending work."""
127144
with self._lock:
128145
self._done = True
129-
for p in self._processes:
130-
kill_process_ignore_exceptions(p)
146+
for pdata in self._processes.values():
147+
kill_process_ignore_exceptions(pdata.process)
148+
149+
def pause_all_processes(self):
150+
with self._lock:
151+
if self._paused:
152+
return
153+
self._paused = True
154+
155+
cur_time = time.time()
156+
for pid, pdata in self._processes.items():
157+
pdata.timeout.cancel()
158+
pdata.time_left -= cur_time - pdata.start_time
159+
if pdata.time_left > 0:
160+
# os.kill here only delivers the STOP signal; it does not kill the process
161+
os.kill(pid, signal.SIGSTOP)
162+
else:
163+
# In case we cancelled right after the timeout expired,
164+
# but before actually killing the process.
165+
kill_process_ignore_exceptions(pdata.process)
166+
167+
def resume_all_processes(self):
168+
with self._lock:
169+
if not self._paused:
170+
return
171+
self._paused = False
172+
173+
cur_time = time.time()
174+
for pid, pdata in self._processes.items():
175+
pdata.timeout = threading.Timer(pdata.time_left,
176+
kill_process_ignore_exceptions,
177+
(pdata.process,))
178+
pdata.timeout.start()
179+
pdata.start_time = cur_time
180+
# os.kill here only delivers the CONTINUE signal; it does not kill the process
181+
os.kill(pid, signal.SIGCONT)
131182

132183
def unregister_process(self, p: 'subprocess.Popen[bytes]'):
133184
with self._lock:
134-
if not self._done:
135-
self._processes.remove(p)
185+
if p.pid in self._processes:
186+
self._processes[p.pid].timeout.cancel()
187+
del self._processes[p.pid]
188+
189+
def __del__(self):
190+
if len(self._processes) > 0:
191+
raise RuntimeError('Cancellation manager deleted while containing items.')
136192

137193

138194
def start_cancellable_process(
139195
cmdline: List[str],
140-
timeout: float,
141196
cancellation_manager: Optional[WorkerCancellationManager],
142197
want_output: bool = False) -> Optional[bytes]:
143198
"""Start a cancellable process.
@@ -166,14 +221,10 @@ def start_cancellable_process(
166221
if cancellation_manager:
167222
cancellation_manager.register_process(p)
168223

169-
try:
170-
retcode = p.wait(timeout=timeout)
171-
except subprocess.TimeoutExpired as e:
172-
kill_process_ignore_exceptions(p)
173-
raise e
174-
finally:
175-
if cancellation_manager:
176-
cancellation_manager.unregister_process(p)
224+
retcode = p.wait()
225+
226+
if cancellation_manager:
227+
cancellation_manager.unregister_process(p)
177228
if retcode != 0:
178229
raise ProcessKilledError(
179230
) if retcode == -9 else subprocess.CalledProcessError(retcode, cmdline)
@@ -249,12 +300,16 @@ class CompilationRunner(Worker):
249300

250301
@classmethod
251302
def is_priority_method(cls, method_name: str) -> bool:
252-
return method_name == 'cancel_all_work'
253-
254-
def __init__(self,
255-
clang_path: Optional[str] = None,
256-
launcher_path: Optional[str] = None,
257-
moving_average_decay_rate: float = 1):
303+
return method_name in {
304+
'cancel_all_work', 'pause_all_work', 'resume_all_work'
305+
}
306+
307+
def __init__(
308+
self,
309+
clang_path: Optional[str] = None,
310+
launcher_path: Optional[str] = None,
311+
moving_average_decay_rate: float = 1,
312+
cancellation_manager: Optional[WorkerCancellationManager] = None):
258313
"""Initialization of CompilationRunner class.
259314
260315
Args:
@@ -265,8 +320,8 @@ def __init__(self,
265320
self._clang_path = clang_path
266321
self._launcher_path = launcher_path
267322
self._moving_average_decay_rate = moving_average_decay_rate
268-
self._compilation_timeout = _COMPILATION_TIMEOUT.value
269-
self._cancellation_manager = WorkerCancellationManager()
323+
self._cancellation_manager = (
324+
cancellation_manager or WorkerCancellationManager())
270325

271326
# re-allow the cancellation manager to accept work.
272327
def enable(self):
@@ -275,6 +330,12 @@ def enable(self):
275330
def cancel_all_work(self):
276331
self._cancellation_manager.kill_all_processes()
277332

333+
def pause_all_work(self):
334+
self._cancellation_manager.pause_all_processes()
335+
336+
def resume_all_work(self):
337+
self._cancellation_manager.resume_all_processes()
338+
278339
def collect_data(
279340
self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
280341
reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult:

compiler_opt/rl/compilation_runner_test.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
import string
1919
import subprocess
20+
import threading
2021
import time
2122
from unittest import mock
2223

@@ -213,9 +214,9 @@ def test_exception_handling(self, mock_compile_fn):
213214
self.assertEqual(1, mock_compile_fn.call_count)
214215

215216
def test_start_subprocess_output(self):
216-
ct = compilation_runner.WorkerCancellationManager()
217+
cm = compilation_runner.WorkerCancellationManager(100)
217218
output = compilation_runner.start_cancellable_process(
218-
['ls', '-l'], timeout=100, cancellation_manager=ct, want_output=True)
219+
['ls', '-l'], cancellation_manager=cm, want_output=True)
219220
if output:
220221
output_str = output.decode('utf-8')
221222
else:
@@ -227,14 +228,31 @@ def test_timeout_kills_process(self):
227228
'test_timeout_kills_test_file')
228229
if os.path.exists(sentinel_file):
229230
os.remove(sentinel_file)
230-
with self.assertRaises(subprocess.TimeoutExpired):
231+
with self.assertRaises(compilation_runner.ProcessKilledError):
232+
cm = compilation_runner.WorkerCancellationManager(0.5)
231233
compilation_runner.start_cancellable_process(
232234
['bash', '-c', 'sleep 1s ; touch ' + sentinel_file],
233-
timeout=0.5,
234-
cancellation_manager=None)
235+
cancellation_manager=cm)
235236
time.sleep(2)
236237
self.assertFalse(os.path.exists(sentinel_file))
237238

239+
def test_pause_resume(self):
240+
# This also makes sure timeouts are restored properly.
241+
cm = compilation_runner.WorkerCancellationManager(1)
242+
start_time = time.time()
243+
244+
def stop_and_start():
245+
time.sleep(0.25)
246+
cm.pause_all_processes()
247+
time.sleep(2)
248+
cm.resume_all_processes()
249+
250+
threading.Thread(target=stop_and_start).start()
251+
compilation_runner.start_cancellable_process(['sleep', '0.5'],
252+
cancellation_manager=cm)
253+
# should be at least 2 seconds due to the pause.
254+
self.assertGreater(time.time() - start_time, 2)
255+
238256

239257
if __name__ == '__main__':
240258
tf.test.main()

compiler_opt/rl/inlining/inlining_runner.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,10 @@ def _compile_fn(
9090
command_line.extend(
9191
['-mllvm', '-ml-inliner-model-under-training=' + tf_policy_path])
9292
compilation_runner.start_cancellable_process(command_line,
93-
self._compilation_timeout,
9493
cancellation_manager)
9594
command_line = [self._llvm_size_path, output_native_path]
9695
output_bytes = compilation_runner.start_cancellable_process(
9796
command_line,
98-
timeout=self._compilation_timeout,
9997
cancellation_manager=cancellation_manager,
10098
want_output=True)
10199
if not output_bytes:

compiler_opt/rl/local_data_collector_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,13 @@ def mock_collect_data(module_spec, tf_policy_dir, reward_stat):
8080
class Sleeper(compilation_runner.CompilationRunner):
8181
"""Test CompilationRunner that just sleeps."""
8282

83+
def __init__(self):
84+
super().__init__(
85+
cancellation_manager=compilation_runner.WorkerCancellationManager(3600))
86+
8387
def collect_data(self, module_spec, tf_policy_path, reward_stat):
8488
_ = module_spec, tf_policy_path, reward_stat
85-
compilation_runner.start_cancellable_process(['sleep', '3600s'], 3600,
89+
compilation_runner.start_cancellable_process(['sleep', '3600s'],
8690
self._cancellation_manager)
8791

8892
return compilation_runner.CompilationResult(

compiler_opt/rl/regalloc/regalloc_runner.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ def _compile_fn(
8787
if tf_policy_path:
8888
command_line.extend(['-mllvm', '-regalloc-model=' + tf_policy_path])
8989
compilation_runner.start_cancellable_process(command_line,
90-
self._compilation_timeout,
9190
cancellation_manager)
9291

9392
sequence_example = struct_pb2.Struct()

0 commit comments

Comments
 (0)