Skip to content

Commit d3ee08e

Browse files
committed
Remove custom timeout
1 parent 3c71bf3 commit d3ee08e

File tree

7 files changed

+43
-85
lines changed

7 files changed

+43
-85
lines changed

compiler_opt/distributed/local/local_worker_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,16 @@
3131
import dataclasses
3232
import functools
3333
import multiprocessing
34-
import threading
3534
import psutil
35+
import threading
3636

3737
from absl import logging
3838
# pylint: disable=unused-import
3939
from compiler_opt.distributed.worker import Worker
4040

4141
from contextlib import AbstractContextManager
4242
from multiprocessing import connection
43-
from typing import Any, Callable, Dict, Optional, List
43+
from typing import Any, Callable, Dict, List, Optional
4444

4545

4646
@dataclasses.dataclass(frozen=True)

compiler_opt/distributed/worker.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
"""Common abstraction for a worker contract."""
1616

17-
from typing import Iterable, Optional, TypeVar, Protocol, runtime_checkable
17+
from typing import Iterable, Optional, Protocol, TypeVar
1818

1919

2020
class Worker(Protocol):
@@ -25,17 +25,6 @@ def is_priority_method(cls, method_name: str) -> bool:
2525
return False
2626

2727

28-
@runtime_checkable
29-
class ContextAwareWorker(Worker, Protocol):
30-
"""ContextAwareWorkers use set_context to modify internal state, this allows
31-
it to behave differently when run remotely vs locally. The user of a
32-
ContextAwareWorker can check for this with isinstance(obj, ContextAwareWorker)
33-
"""
34-
35-
def set_context(self, local: bool) -> None:
36-
return
37-
38-
3928
T = TypeVar('T')
4029

4130

compiler_opt/rl/compilation_runner.py

Lines changed: 26 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import signal
2222
import subprocess
2323
import threading
24-
import time
2524
from typing import Dict, List, Optional, Tuple
2625

2726
from absl import flags
@@ -104,22 +103,14 @@ class WorkerCancellationManager:
104103
managing resources.
105104
"""
106105

107-
@dataclasses.dataclass
108-
class ProcData:
109-
process: 'subprocess.Popen[bytes]'
110-
timeout: threading.Timer
111-
time_left: float
112-
start_time: float
113-
114-
def __init__(self, timeout: float = _COMPILATION_TIMEOUT.value):
106+
def __init__(self):
115107
# the queue is filled only by workers, and drained only by the single
116108
# consumer. we use _done to manage access to the queue. We can then assume
117109
# empty() is accurate and get() never blocks.
118-
self._processes: Dict[int, WorkerCancellationManager.ProcData] = {}
110+
self._processes = set()
119111
self._done = False
120112
self._paused = False
121113
self._lock = threading.Lock()
122-
self._timeout = timeout
123114

124115
def enable(self):
125116
with self._lock:
@@ -129,65 +120,40 @@ def register_process(self, p: 'subprocess.Popen[bytes]'):
129120
"""Register a process for potential cancellation."""
130121
with self._lock:
131122
if not self._done:
132-
self._processes[p.pid] = self.ProcData(
133-
process=p,
134-
timeout=threading.Timer(self._timeout,
135-
kill_process_ignore_exceptions, (p,)),
136-
time_left=self._timeout,
137-
start_time=time.time())
138-
if self._paused:
139-
os.kill(p.pid, signal.SIGSTOP)
140-
else:
141-
self._processes[p.pid].timeout.start()
123+
self._processes.add(p)
142124
return
143125
kill_process_ignore_exceptions(p)
144126

145127
def kill_all_processes(self):
146128
"""Cancel any pending work."""
147129
with self._lock:
148130
self._done = True
149-
for pdata in self._processes.values():
150-
kill_process_ignore_exceptions(pdata.process)
131+
for p in self._processes:
132+
kill_process_ignore_exceptions(p)
151133

152134
def pause_all_processes(self):
153135
with self._lock:
154136
if self._paused:
155137
return
156138
self._paused = True
157139

158-
cur_time = time.time()
159-
for pid, pdata in self._processes.items():
160-
pdata.timeout.cancel()
161-
pdata.time_left -= cur_time - pdata.start_time
162-
if pdata.time_left > 0:
163-
# used to send the STOP signal; does not actually kill the process
164-
os.kill(pid, signal.SIGSTOP)
165-
else:
166-
# In case we cancelled right after the timeout expired,
167-
# but before actually killing the process.
168-
kill_process_ignore_exceptions(pdata.process)
140+
for p in self._processes:
141+
# used to send the STOP signal; does not actually kill the process
142+
os.kill(p.pid, signal.SIGSTOP)
169143

170144
def resume_all_processes(self):
171145
with self._lock:
172146
if not self._paused:
173147
return
174148
self._paused = False
175149

176-
cur_time = time.time()
177-
for pid, pdata in self._processes.items():
178-
pdata.timeout = threading.Timer(pdata.time_left,
179-
kill_process_ignore_exceptions,
180-
(pdata.process,))
181-
pdata.timeout.start()
182-
pdata.start_time = cur_time
150+
for p in self._processes:
183151
# used to send the CONTINUE signal; does not actually kill the process
184-
os.kill(pid, signal.SIGCONT)
152+
os.kill(p.pid, signal.SIGCONT)
185153

186154
def unregister_process(self, p: 'subprocess.Popen[bytes]'):
# Stop tracking p: under the lock, drop it from the set of registered
# processes.
187155
with self._lock:
188-
if p.pid in self._processes:
189-
self._processes[p.pid].timeout.cancel()
190-
del self._processes[p.pid]
156+
# NOTE(review): set.remove() raises KeyError when p is not in the set.
# register_process skips the add (and kills p) once _done is set, while
# start_cancellable_process calls unregister_process from its finally
# clause; the replaced lines guarded this with a membership check.
# Consider set.discard(p) instead - TODO confirm.
self._processes.remove(p)
191157

192158
def __del__(self):
193159
if len(self._processes) > 0:
@@ -196,6 +162,7 @@ def __del__(self):
196162

197163
def start_cancellable_process(
198164
cmdline: List[str],
165+
timeout: float,
199166
cancellation_manager: Optional[WorkerCancellationManager],
200167
want_output: bool = False) -> Optional[bytes]:
201168
"""Start a cancellable process.
@@ -224,10 +191,15 @@ def start_cancellable_process(
224191
if cancellation_manager:
225192
cancellation_manager.register_process(p)
226193

227-
retcode = p.wait()
194+
try:
195+
retcode = p.wait(timeout=timeout)
196+
except subprocess.TimeoutExpired as e:
197+
kill_process_ignore_exceptions(p)
198+
raise e
199+
finally:
200+
if cancellation_manager:
201+
cancellation_manager.unregister_process(p)
228202

229-
if cancellation_manager:
230-
cancellation_manager.unregister_process(p)
231203
if retcode != 0:
232204
raise ProcessKilledError(
233205
) if retcode == -9 else subprocess.CalledProcessError(retcode, cmdline)
@@ -307,12 +279,10 @@ def is_priority_method(cls, method_name: str) -> bool:
307279
'cancel_all_work', 'pause_all_work', 'resume_all_work'
308280
}
309281

310-
def __init__(
311-
self,
312-
clang_path: Optional[str] = None,
313-
launcher_path: Optional[str] = None,
314-
moving_average_decay_rate: float = 1,
315-
cancellation_manager: Optional[WorkerCancellationManager] = None):
282+
def __init__(self,
283+
clang_path: Optional[str] = None,
284+
launcher_path: Optional[str] = None,
285+
moving_average_decay_rate: float = 1):
316286
"""Initialization of CompilationRunner class.
317287
318288
Args:
@@ -323,8 +293,8 @@ def __init__(
323293
self._clang_path = clang_path
324294
self._launcher_path = launcher_path
325295
self._moving_average_decay_rate = moving_average_decay_rate
326-
self._cancellation_manager = (
327-
cancellation_manager or WorkerCancellationManager())
296+
self._compilation_timeout = _COMPILATION_TIMEOUT.value
297+
self._cancellation_manager = WorkerCancellationManager()
328298

329299
# re-allow the cancellation manager to accept work.
330300
def enable(self):

compiler_opt/rl/compilation_runner_test.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -214,9 +214,9 @@ def test_exception_handling(self, mock_compile_fn):
214214
self.assertEqual(1, mock_compile_fn.call_count)
215215

216216
def test_start_subprocess_output(self):
217-
cm = compilation_runner.WorkerCancellationManager(100)
217+
cm = compilation_runner.WorkerCancellationManager()
218218
output = compilation_runner.start_cancellable_process(
219-
['ls', '-l'], cancellation_manager=cm, want_output=True)
219+
['ls', '-l'], timeout=100, cancellation_manager=cm, want_output=True)
220220
if output:
221221
output_str = output.decode('utf-8')
222222
else:
@@ -228,30 +228,30 @@ def test_timeout_kills_process(self):
228228
'test_timeout_kills_test_file')
229229
if os.path.exists(sentinel_file):
230230
os.remove(sentinel_file)
231-
with self.assertRaises(compilation_runner.ProcessKilledError):
232-
cm = compilation_runner.WorkerCancellationManager(0.5)
231+
with self.assertRaises(subprocess.TimeoutExpired):
233232
compilation_runner.start_cancellable_process(
234233
['bash', '-c', 'sleep 1s ; touch ' + sentinel_file],
235-
cancellation_manager=cm)
234+
timeout=0.5,
235+
cancellation_manager=None)
236236
time.sleep(2)
237237
self.assertFalse(os.path.exists(sentinel_file))
238238

239239
def test_pause_resume(self):
240-
# This also makes sure timeouts are restored properly.
241-
cm = compilation_runner.WorkerCancellationManager(1)
240+
cm = compilation_runner.WorkerCancellationManager()
242241
start_time = time.time()
243242

244243
def stop_and_start():
245244
time.sleep(0.25)
246245
cm.pause_all_processes()
247-
time.sleep(2)
246+
time.sleep(1)
248247
cm.resume_all_processes()
249248

250249
threading.Thread(target=stop_and_start).start()
251250
compilation_runner.start_cancellable_process(['sleep', '0.5'],
251+
30,
252252
cancellation_manager=cm)
253-
# should be at least 2 seconds due to the pause.
254-
self.assertGreater(time.time() - start_time, 2)
253+
# should be at least 1 second due to the pause.
254+
self.assertGreater(time.time() - start_time, 1)
255255

256256

257257
if __name__ == '__main__':

compiler_opt/rl/inlining/inlining_runner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,12 @@ def compile_fn(
9090
command_line.extend(
9191
['-mllvm', '-ml-inliner-model-under-training=' + tf_policy_path])
9292
compilation_runner.start_cancellable_process(command_line,
93+
self._compilation_timeout,
9394
cancellation_manager)
9495
command_line = [self._llvm_size_path, output_native_path]
9596
output_bytes = compilation_runner.start_cancellable_process(
9697
command_line,
98+
timeout=self._compilation_timeout,
9799
cancellation_manager=cancellation_manager,
98100
want_output=True)
99101
if not output_bytes:

compiler_opt/rl/local_data_collector_test.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,9 @@ def mock_collect_data(module_spec, tf_policy_dir, reward_stat):
8080
class Sleeper(compilation_runner.CompilationRunner):
8181
"""Test CompilationRunner that just sleeps."""
8282

83-
def __init__(self):
84-
super().__init__(
85-
cancellation_manager=compilation_runner.WorkerCancellationManager(3600))
86-
8783
def collect_data(self, module_spec, tf_policy_path, reward_stat):
8884
_ = module_spec, tf_policy_path, reward_stat
89-
compilation_runner.start_cancellable_process(['sleep', '3600s'],
85+
compilation_runner.start_cancellable_process(['sleep', '3600s'], 3600,
9086
self._cancellation_manager)
9187

9288
return compilation_runner.CompilationResult(

compiler_opt/rl/regalloc/regalloc_runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def compile_fn(
8787
if tf_policy_path:
8888
command_line.extend(['-mllvm', '-regalloc-model=' + tf_policy_path])
8989
compilation_runner.start_cancellable_process(command_line,
90+
self._compilation_timeout,
9091
cancellation_manager)
9192

9293
sequence_example = struct_pb2.Struct()

0 commit comments

Comments
 (0)