Skip to content

Commit d46095b

Browse files
authored
Make generate_default_trace fail-fast when hitting a bug (#61)
If the worker process hits an unrecoverable error, we stop processing instead of blocking indefinitely.
1 parent ba69015 commit d46095b

File tree

1 file changed

+42
-32
lines changed

1 file changed

+42
-32
lines changed

compiler_opt/tools/generate_default_trace.py

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import random
2222
import re
2323
import subprocess
24-
from typing import Dict, List, Optional, Tuple, Type # pylint:disable=unused-import
24+
from typing import Dict, List, Optional, Union, Tuple # pylint:disable=unused-import
2525

2626
from absl import app
2727
from absl import flags
@@ -66,8 +66,10 @@
6666
'gin_bindings', [],
6767
'Gin bindings to override the values set in the config files.')
6868

69-
ResultsQueueEntry = Optional[Tuple[str, List[str],
70-
Dict[str, compilation_runner.RewardStat]]]
69+
ResultsQueueEntry = Union[Optional[Tuple[str, List[str],
70+
Dict[str,
71+
compilation_runner.RewardStat]]],
72+
BaseException]
7173

7274

7375
def get_runner() -> compilation_runner.CompilationRunner:
@@ -81,7 +83,7 @@ def get_runner() -> compilation_runner.CompilationRunner:
8183

8284

8385
def worker(policy_path: str, work_queue: 'queue.Queue[corpus.ModuleSpec]',
84-
results_queue: 'queue.Queue[Optional[List[str]]]',
86+
results_queue: 'queue.Queue[ResultsQueueEntry]',
8587
key_filter: Optional[str]):
8688
"""Describes the job each paralleled worker process does.
8789
@@ -97,35 +99,41 @@ def worker(policy_path: str, work_queue: 'queue.Queue[corpus.ModuleSpec]',
9799
results_queue: the queue where results are deposited.
98100
key_filter: regex filter for key names to include, or None to include all.
99101
"""
100-
runner = get_runner()
101-
m = re.compile(key_filter) if key_filter else None
102-
103-
while True:
104-
try:
105-
module_spec = work_queue.get_nowait()
106-
except queue.Empty:
107-
return
108-
try:
109-
data = runner.collect_data(
110-
module_spec=module_spec, tf_policy_path=policy_path, reward_stat=None)
111-
if not m:
112-
results_queue.put((module_spec.name, data.serialized_sequence_examples,
113-
data.reward_stats))
114-
continue
115-
new_reward_stats = {}
116-
new_sequence_examples = []
117-
for k, sequence_example in zip(data.keys,
118-
data.serialized_sequence_examples):
119-
if not m.match(k):
102+
try:
103+
runner = get_runner()
104+
m = re.compile(key_filter) if key_filter else None
105+
106+
while True:
107+
try:
108+
module_spec = work_queue.get_nowait()
109+
except queue.Empty:
110+
return
111+
try:
112+
data = runner.collect_data(
113+
module_spec=module_spec,
114+
tf_policy_path=policy_path,
115+
reward_stat=None)
116+
if not m:
117+
results_queue.put(
118+
(module_spec.name, data.serialized_sequence_examples,
119+
data.reward_stats))
120120
continue
121-
new_reward_stats[k] = data.reward_stats[k]
122-
new_sequence_examples.append(sequence_example)
123-
results_queue.put(
124-
(module_spec.name, new_sequence_examples, new_reward_stats))
125-
except (subprocess.CalledProcessError, subprocess.TimeoutExpired,
126-
RuntimeError):
127-
logging.error('Failed to compile %s.', module_spec.name)
128-
results_queue.put(None)
121+
new_reward_stats = {}
122+
new_sequence_examples = []
123+
for k, sequence_example in zip(data.keys,
124+
data.serialized_sequence_examples):
125+
if not m.match(k):
126+
continue
127+
new_reward_stats[k] = data.reward_stats[k]
128+
new_sequence_examples.append(sequence_example)
129+
results_queue.put(
130+
(module_spec.name, new_sequence_examples, new_reward_stats))
131+
except (subprocess.CalledProcessError, subprocess.TimeoutExpired,
132+
RuntimeError):
133+
logging.error('Failed to compile %s.', module_spec.name)
134+
results_queue.put(None)
135+
except BaseException as e: # pylint: disable=broad-except
136+
results_queue.put(e)
129137

130138

131139
def main(_):
@@ -206,6 +214,8 @@ def main(_):
206214
total_failed_examples, total_work)
207215

208216
results = results_queue.get()
217+
if isinstance(results, BaseException):
218+
logging.fatal(results)
209219
if not results:
210220
total_failed_examples += 1
211221
continue

0 commit comments

Comments
 (0)