Commit f89c355

Switch generate_default_trace to worker infrastructure. (#222)
1 parent 2f63d6b commit f89c355

File tree: 2 files changed (+88 additions, -115 deletions)

compiler_opt/tools/generate_default_trace.py

Lines changed: 84 additions & 113 deletions
@@ -14,26 +14,30 @@
 # limitations under the License.
 """Generate initial training data from the behavior of the current heuristic."""
 
+import concurrent.futures
 import contextlib
 import functools
-import os
-import queue
 import re
-import subprocess
 from typing import Dict, List, Optional, Union, Tuple  # pylint:disable=unused-import
 
 from absl import app
 from absl import flags
 from absl import logging
 import gin
-import multiprocessing
+
 import tensorflow as tf
 
+from compiler_opt.distributed import worker
+from compiler_opt.distributed import buffered_scheduler
+from compiler_opt.distributed.local import local_worker_manager
+
 from compiler_opt.rl import compilation_runner
 from compiler_opt.rl import corpus
 from compiler_opt.rl import policy_saver
 from compiler_opt.rl import registry
 
+from tf_agents.system import system_multiprocessing as multiprocessing
+
 # see https://bugs.python.org/issue33315 - we do need these types, but must
 # currently use them as string annotations
 
@@ -76,64 +80,45 @@ def get_runner() -> compilation_runner.CompilationRunner:
   return problem_config.get_runner_type()(moving_average_decay_rate=0)
 
 
-def worker(policy_path: Optional[str],
-           work_queue: 'queue.Queue[corpus.LoadedModuleSpec]',
-           results_queue: 'queue.Queue[ResultsQueueEntry]',
-           key_filter: Optional[str]):
-  """Describes the job each paralleled worker process does.
+class FilteringWorker(worker.Worker):
+  """Worker that performs a computation and optionally filters the result.
 
-  The worker picks a workitem from the work_queue, process it, and deposits
-  a result on the results_queue, in either success or failure cases.
-  The results_queue items are tuples (workitem, result). On failure, the result
-  is None.
 
   Args:
-    runner: the data collector.
     policy_path: the policy_path to generate trace with.
-    work_queue: the queue of unprocessed work items.
-    results_queue: the queue where results are deposited.
    key_filter: regex filter for key names to include, or None to include all.
   """
-  try:
-    runner = get_runner()
-    m = re.compile(key_filter) if key_filter else None
-    policy = policy_saver.Policy.from_filesystem(
+
+  def __init__(self, policy_path: Optional[str], key_filter: Optional[str]):
+    self._policy_path = policy_path
+    self._key_filter = re.compile(key_filter) if key_filter else None
+    self._runner = get_runner()
+    self._policy = policy_saver.Policy.from_filesystem(
         policy_path) if policy_path else None
-    while True:
-      try:
-        loaded_module_spec = work_queue.get_nowait()
-      except queue.Empty:
-        return
-      try:
-        data = runner.collect_data(
-            loaded_module_spec=loaded_module_spec,
-            policy=policy,
-            reward_stat=None,
-            model_id=0)
-        if not m:
-          results_queue.put(
-              (loaded_module_spec.name, data.serialized_sequence_examples,
-               data.reward_stats))
-          continue
-        new_reward_stats = {}
-        new_sequence_examples = []
-        for k, sequence_example in zip(data.keys,
-                                       data.serialized_sequence_examples):
-          if not m.match(k):
-            continue
-          new_reward_stats[k] = data.reward_stats[k]
-          new_sequence_examples.append(sequence_example)
-        results_queue.put(
-            (loaded_module_spec.name, new_sequence_examples, new_reward_stats))
-      except (subprocess.CalledProcessError, subprocess.TimeoutExpired,
-              RuntimeError):
-        logging.error('Failed to compile %s.', loaded_module_spec.name)
-        results_queue.put(None)
-  except BaseException as e:  # pylint: disable=broad-except
-    results_queue.put(e)
-
-
-def main(_):
+
+  def compile_and_filter(
+      self, loaded_module_spec: corpus.LoadedModuleSpec
+  ) -> Tuple[str, List[str], Dict[str, compilation_runner.RewardStat]]:
+    data = self._runner.collect_data(
+        loaded_module_spec=loaded_module_spec,
+        policy=self._policy,
+        reward_stat=None,
+        model_id=0)
+    if self._key_filter is None:
+      return (loaded_module_spec.name, data.serialized_sequence_examples,
+              data.reward_stats)
+    new_reward_stats = {}
+    new_sequence_examples = []
+    for k, sequence_example in zip(data.keys,
+                                   data.serialized_sequence_examples):
+      if not self._key_filter.match(k):
+        continue
+      new_reward_stats[k] = data.reward_stats[k]
+      new_sequence_examples.append(sequence_example)
+    return (loaded_module_spec.name, new_sequence_examples, new_reward_stats)
+
+
+def main(worker_manager_class=local_worker_manager.LocalWorkerPoolManager):
 
   gin.parse_config_files_and_bindings(
       _GIN_FILES.value, bindings=_GIN_BINDINGS.value, skip_unknown=False)
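
Note (not part of the commit): FilteringWorker compiles the key_filter regex once in its constructor and then keeps only the sequence examples whose key matches it. The sketch below reproduces just that filtering step in isolation; the function and variable names are illustrative, not from the repository.

import re
from typing import Dict, List, Optional, Tuple


def filter_by_key(
    key_filter: Optional[str], keys: List[str], examples: List[str],
    reward_stats: Dict[str, float]) -> Tuple[List[str], Dict[str, float]]:
  # None means "keep everything", mirroring the `if self._key_filter is None` branch.
  matcher = re.compile(key_filter) if key_filter else None
  if matcher is None:
    return examples, reward_stats
  kept_examples, kept_stats = [], {}
  for k, example in zip(keys, examples):
    if not matcher.match(k):  # re.match anchors at the start of the key
      continue
    kept_examples.append(example)
    kept_stats[k] = reward_stats[k]
  return kept_examples, kept_stats


# Keep only modules whose key starts with 'foo'.
assert filter_by_key('foo', ['foo/a', 'bar/b'], ['ex_a', 'ex_b'],
                     {'foo/a': 1.0, 'bar/b': 2.0}) == (['ex_a'], {'foo/a': 1.0})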
@@ -160,74 +145,60 @@ def main(_):
   # other smaller files are processed in parallel
   corpus_elements = cps.sample(k=sampled_modules, sort=True)
 
-  worker_count = (
-      min(os.cpu_count(), _NUM_WORKERS.value)
-      if _NUM_WORKERS.value else os.cpu_count())
-
   tfrecord_context = (
       tf.io.TFRecordWriter(_OUTPUT_PATH.value)
       if _OUTPUT_PATH.value else contextlib.nullcontext())
   performance_context = (
       tf.io.gfile.GFile(_OUTPUT_PERFORMANCE_PATH.value, 'w')
       if _OUTPUT_PERFORMANCE_PATH.value else contextlib.nullcontext())
+  work = [
+      cps.load_module_spec(corpus_element) for corpus_element in corpus_elements
+  ]
 
   with tfrecord_context as tfrecord_writer:
     with performance_context as performance_writer:
-      ctx = multiprocessing.get_context()
-      m = ctx.Manager()
-      results_queue: 'queue.Queue[ResultsQueueEntry]' = m.Queue()
-      work_queue: 'queue.Queue[corpus.LoadedModuleSpec]' = m.Queue()
-      for corpus_element in corpus_elements:
-        work_queue.put(cps.load_module_spec(corpus_element))
-
-      # pylint:disable=g-complex-comprehension
-      processes = [
-          ctx.Process(
-              target=functools.partial(worker, _POLICY_PATH.value, work_queue,
-                                       results_queue, _KEY_FILTER.value))
-          for _ in range(0, worker_count)
-      ]
-      # pylint:enable=g-complex-comprehension
-
-      for p in processes:
-        p.start()
-
-      total_successful_examples = 0
-      total_work = len(corpus_elements)
-      total_failed_examples = 0
-      total_training_examples = 0
-      for _ in range(total_work):
-        logging.log_every_n_seconds(logging.INFO,
-                                    '%d success, %d failed out of %d', 10,
-                                    total_successful_examples,
-                                    total_failed_examples, total_work)
-
-        results = results_queue.get()
-        if isinstance(results, BaseException):
-          logging.fatal(results)
-        if not results:
-          total_failed_examples += 1
-          continue
-
-        total_successful_examples += 1
-        module_name, records, reward_stat = results
-        if tfrecord_writer:
-          total_training_examples += len(records)
-          for r in records:
-            tfrecord_writer.write(r)
-        if performance_writer:
-          for key, value in reward_stat.items():
-            performance_writer.write(
-                (f'{module_name},{key},{value.default_reward},'
-                 f'{value.moving_average_reward}\n'))
-
-      print((f'{total_successful_examples} of {len(corpus_elements)} modules '
-             f'succeeded, and {total_training_examples} trainining examples '
-             'written'))
-  for p in processes:
-    p.join()
+      with worker_manager_class(
+          FilteringWorker,
+          _NUM_WORKERS.value,
+          policy_path=_POLICY_PATH.value,
+          key_filter=_KEY_FILTER.value) as lwm:
+
+        _, result_futures = buffered_scheduler.schedule_on_worker_pool(
+            action=lambda w, j: w.compile_and_filter(j),
+            jobs=work,
+            worker_pool=lwm)
+        total_successful_examples = 0
+        total_work = len(corpus_elements)
+        total_failed_examples = 0
+        total_training_examples = 0
+        not_done = result_futures
+        while not_done:
+          (done, not_done) = concurrent.futures.wait(not_done, 10)
+          succeeded = [
+              r for r in done if not r.cancelled() and r.exception() is None
+          ]
+          total_successful_examples += len(succeeded)
+          total_failed_examples += (len(done) - len(succeeded))
+          for r in succeeded:
+            module_name, records, reward_stat = r.result()
+            if tfrecord_writer:
+              total_training_examples += len(records)
+              for r in records:
+                tfrecord_writer.write(r)
+            if performance_writer:
+              for key, value in reward_stat.items():
+                performance_writer.write(
+                    (f'{module_name},{key},{value.default_reward},'
+                     f'{value.moving_average_reward}\n'))
+          logging.info('%d success, %d failed out of %d',
+                       total_successful_examples, total_failed_examples,
+                       total_work)
+
+  print((f'{total_successful_examples} of {len(corpus_elements)} modules '
+         f'succeeded, and {total_training_examples} trainining examples '
+         'written'))
 
 
 if __name__ == '__main__':
   flags.mark_flag_as_required('data_path')
-  app.run(main)
+  multiprocessing.handle_main(functools.partial(app.run, main))
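
Note (not part of the commit): the new main() hands every loaded module to the worker pool via buffered_scheduler.schedule_on_worker_pool, then drains the returned futures in 10-second waves, tallying successes and failures as each wave completes. Below is a minimal stand-alone sketch of that drain loop, substituting the standard library's ThreadPoolExecutor for the repo's worker pool; all names are illustrative.

import concurrent.futures


def compile_one(name: str) -> str:
  # Stand-in for FilteringWorker.compile_and_filter; an exception marks a failure.
  if name == 'broken_module':
    raise RuntimeError(f'failed to compile {name}')
  return name


jobs = ['module_a', 'module_b', 'broken_module']
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
  not_done = [pool.submit(compile_one, j) for j in jobs]
  succeeded_count, failed_count = 0, 0
  while not_done:
    # Wait up to 10 seconds, tally whatever finished, then keep waiting on the rest.
    done, not_done = concurrent.futures.wait(not_done, timeout=10)
    ok = [f for f in done if not f.cancelled() and f.exception() is None]
    succeeded_count += len(ok)
    failed_count += len(done) - len(ok)
print(f'{succeeded_count} succeeded, {failed_count} failed out of {len(jobs)}')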

compiler_opt/tools/generate_default_trace_test.py

Lines changed: 4 additions & 2 deletions
@@ -28,6 +28,8 @@
 from compiler_opt.rl import compilation_runner
 from compiler_opt.tools import generate_default_trace
 
+from tf_agents.system import system_multiprocessing as multiprocessing
+
 flags.FLAGS['num_workers'].allow_override = True
 flags.FLAGS['gin_files'].allow_override = True
 flags.FLAGS['gin_bindings'].allow_override = True
@@ -105,12 +107,12 @@ def test_api(self, mock_get_runner):
         output_performance_path=os.path.join(tmp_dir.full_path,
                                              'output_performance'),
     ):
-      generate_default_trace.main(None)
+      generate_default_trace.main()
 
   def test_get_runner(self):
     runner = generate_default_trace.get_runner()
     self.assertIsInstance(runner, compilation_runner.CompilationRunner)
 
 
 if __name__ == '__main__':
-  absltest.main()
+  multiprocessing.handle_main(absltest.main)
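
Note (not part of the commit): because main() now takes worker_manager_class as a default argument, a caller such as a test could in principle inject a substitute pool manager instead of local_worker_manager.LocalWorkerPoolManager. The stub below is hypothetical; it only illustrates the constructor arguments and context-manager protocol that main() visibly relies on in this diff, and whether such a stub satisfies buffered_scheduler.schedule_on_worker_pool is not shown here.

class StubPoolManager:
  """Hypothetical stand-in for LocalWorkerPoolManager (illustration only)."""

  def __init__(self, worker_cls, num_workers, **worker_kwargs):
    # main() forwards FilteringWorker, --num_workers and the worker kwargs here.
    self._pool = [worker_cls(**worker_kwargs) for _ in range(num_workers or 1)]

  def __enter__(self):
    return self._pool

  def __exit__(self, *exc_info):
    return False

# A test might then exercise the pipeline with:
#   generate_default_trace.main(worker_manager_class=StubPoolManager)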
