Skip to content

Commit 08ba9aa

Browse files
authored
Make generate_default_trace_test closer to actual use (#225)
- making `main` take care of gin bindings, and any additional parameters, so the worker pool type doesn't get overridden by `main`'s args - making sure gin configuration for the underlying worker makes it through. This also fixes the actual tool, since we weren't passing the gin config of the underlying task correctly.
1 parent 2935865 commit 08ba9aa

File tree

2 files changed

+25
-19
lines changed

2 files changed

+25
-19
lines changed

compiler_opt/tools/generate_default_trace.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,6 @@
7575
BaseException]
7676

7777

78-
def get_runner() -> compilation_runner.CompilationRunner:
79-
problem_config = registry.get_configuration()
80-
return problem_config.get_runner_type()(moving_average_decay_rate=0)
81-
82-
8378
class FilteringWorker(worker.Worker):
8479
"""Worker that performs a computation and optionally filters the result.
8580
@@ -89,10 +84,12 @@ class FilteringWorker(worker.Worker):
8984
key_filter: regex filter for key names to include, or None to include all.
9085
"""
9186

92-
def __init__(self, policy_path: Optional[str], key_filter: Optional[str]):
87+
def __init__(self, policy_path: Optional[str], key_filter: Optional[str],
88+
runner_type: 'type[compilation_runner.CompilationRunner]',
89+
runner_kwargs):
9390
self._policy_path = policy_path
9491
self._key_filter = re.compile(key_filter) if key_filter else None
95-
self._runner = get_runner()
92+
self._runner = runner_type(**runner_kwargs)
9693
self._policy = policy_saver.Policy.from_filesystem(
9794
policy_path) if policy_path else None
9895

@@ -118,11 +115,15 @@ def compile_and_filter(
118115
return (loaded_module_spec.name, new_sequence_examples, new_reward_stats)
119116

120117

121-
def main(worker_manager_class=local_worker_manager.LocalWorkerPoolManager):
122-
118+
def main(_):
123119
gin.parse_config_files_and_bindings(
124120
_GIN_FILES.value, bindings=_GIN_BINDINGS.value, skip_unknown=False)
125121
logging.info(gin.config_str())
122+
generate_trace()
123+
124+
125+
def generate_trace(
126+
worker_manager_class=local_worker_manager.LocalWorkerPoolManager):
126127

127128
config = registry.get_configuration()
128129

@@ -155,13 +156,17 @@ def main(worker_manager_class=local_worker_manager.LocalWorkerPoolManager):
155156
cps.load_module_spec(corpus_element) for corpus_element in corpus_elements
156157
]
157158

159+
runner_type = config.get_runner_type()
158160
with tfrecord_context as tfrecord_writer:
159161
with performance_context as performance_writer:
160162
with worker_manager_class(
161163
FilteringWorker,
162164
_NUM_WORKERS.value,
163165
policy_path=_POLICY_PATH.value,
164-
key_filter=_KEY_FILTER.value) as lwm:
166+
key_filter=_KEY_FILTER.value,
167+
runner_type=runner_type,
168+
runner_kwargs=worker.get_full_worker_args(
169+
runner_type, moving_average_decay_rate=0)) as lwm:
165170

166171
_, result_futures = buffered_scheduler.schedule_on_worker_pool(
167172
action=lambda w, j: w.compile_and_filter(j),

compiler_opt/tools/generate_default_trace_test.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,14 @@
3535
flags.FLAGS['gin_bindings'].allow_override = True
3636

3737

38+
@gin.configurable(module='runners')
3839
class MockCompilationRunner(compilation_runner.CompilationRunner):
3940
"""A compilation runner just for test."""
4041

42+
def __init__(self, sentinel=None):
43+
assert sentinel == 42
44+
super().__init__()
45+
4146
def collect_data(self,
4247
loaded_module_spec,
4348
policy=None,
@@ -75,11 +80,11 @@ def setUp(self):
7580
with gin.unlock_config():
7681
gin.parse_config_files_and_bindings(
7782
config_files=['compiler_opt/rl/inlining/gin_configs/common.gin'],
78-
bindings=None)
83+
bindings=['runners.MockCompilationRunner.sentinel=42'])
7984
return super().setUp()
8085

81-
@mock.patch('compiler_opt.tools.generate_default_trace.get_runner')
82-
def test_api(self, mock_get_runner):
86+
@mock.patch('compiler_opt.rl.inlining.InliningConfig.get_runner_type')
87+
def test_generate_trace(self, mock_get_runner):
8388

8489
tmp_dir = self.create_tempdir()
8590
module_names = ['a', 'b', 'c', 'd']
@@ -97,7 +102,7 @@ def test_api(self, mock_get_runner):
97102
os.path.join(tmp_dir.full_path, module_name + '.cmd'), 'w') as f:
98103
f.write('-cc1')
99104

100-
mock_compilation_runner = MockCompilationRunner()
105+
mock_compilation_runner = MockCompilationRunner
101106
mock_get_runner.return_value = mock_compilation_runner
102107

103108
with flagsaver.flagsaver(
@@ -107,11 +112,7 @@ def test_api(self, mock_get_runner):
107112
output_performance_path=os.path.join(tmp_dir.full_path,
108113
'output_performance'),
109114
):
110-
generate_default_trace.main()
111-
112-
def test_get_runner(self):
113-
runner = generate_default_trace.get_runner()
114-
self.assertIsInstance(runner, compilation_runner.CompilationRunner)
115+
generate_default_trace.generate_trace()
115116

116117

117118
if __name__ == '__main__':

0 commit comments

Comments
 (0)