Skip to content

Commit 087d446

Browse files
Add option to generate_default_trace to output keys (#477)
This patch adds an option (the `keys_file` flag) to generate_default_trace that writes the keys associated with the examples to a file, one key per line. It is primarily intended for extracting the functions that contain eviction decisions, which improves the efficiency of the new training workflow for ES regalloc trace modelling. It may also be useful in other circumstances where more introspection into the data is needed.
1 parent 24865b4 commit 087d446

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

compiler_opt/tools/generate_default_trace.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@
6464
_GIN_BINDINGS = flags.DEFINE_multi_string(
6565
'gin_bindings', [],
6666
'Gin bindings to override the values set in the config files.')
67+
_KEYS_FILE = flags.DEFINE_string(
68+
'keys_file', None,
69+
'The path to the file to write out the keys encountered.')
6770

6871

6972
class FilteringWorker(worker.Worker):
@@ -86,24 +89,28 @@ def __init__(self, policy_path: str | None, key_filter: str | None,
8689

8790
def compile_and_filter(
8891
self, loaded_module_spec: corpus.LoadedModuleSpec
89-
) -> tuple[str, list[str], dict[str, compilation_runner.RewardStat]]:
92+
) -> tuple[str, list[str], dict[str, compilation_runner.RewardStat],
93+
list[str]]:
9094
data = self._runner.collect_data(
9195
loaded_module_spec=loaded_module_spec,
9296
policy=self._policy,
9397
reward_stat=None,
9498
model_id=0)
9599
if self._key_filter is None:
96100
return (loaded_module_spec.name, data.serialized_sequence_examples,
97-
data.reward_stats)
101+
data.reward_stats, data.keys)
98102
new_reward_stats = {}
99103
new_sequence_examples = []
104+
new_keys = []
100105
for k, sequence_example in zip(data.keys,
101106
data.serialized_sequence_examples):
102107
if not self._key_filter.match(k):
103108
continue
104109
new_reward_stats[k] = data.reward_stats[k]
105110
new_sequence_examples.append(sequence_example)
106-
return (loaded_module_spec.name, new_sequence_examples, new_reward_stats)
111+
new_keys.append(k)
112+
return (loaded_module_spec.name, new_sequence_examples, new_reward_stats,
113+
new_keys)
107114

108115

109116
def main(_):
@@ -147,6 +154,7 @@ def generate_trace(worker_manager_class: type[
147154
work = [
148155
cps.load_module_spec(corpus_element) for corpus_element in corpus_elements
149156
]
157+
all_keys = []
150158

151159
runner_type = config.get_runner_type()
152160
with tfrecord_context as tfrecord_writer:
@@ -178,7 +186,8 @@ def generate_trace(worker_manager_class: type[
178186
total_successful_examples += len(succeeded)
179187
total_failed_examples += (len(done) - len(succeeded))
180188
for r in succeeded:
181-
module_name, records, reward_stat = r.result()
189+
module_name, records, reward_stat, keys = r.result()
190+
all_keys.extend(keys)
182191
if tfrecord_writer:
183192
total_training_examples += len(records)
184193
for r in records:
@@ -196,6 +205,10 @@ def generate_trace(worker_manager_class: type[
196205
f'succeeded, and {total_training_examples} trainining examples '
197206
'written')
198207

208+
if _KEYS_FILE.value is not None:
209+
with open(_KEYS_FILE.value, 'w', encoding='utf-8') as keys_file:
210+
keys_file.write('\n'.join(str(key) for key in all_keys) + '\n')
211+
199212

200213
if __name__ == '__main__':
201214
flags.mark_flag_as_required('data_path')

compiler_opt/tools/generate_default_trace_test.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,17 @@ def collect_data(self,
6161
sequence_example = text_format.Parse(sequence_example_text,
6262
tf.train.SequenceExample())
6363

64+
key = f'key_{os.getpid()}'
6465
return compilation_runner.CompilationResult(
6566
sequence_examples=[sequence_example],
6667
reward_stats={
67-
'default':
68+
key:
6869
compilation_runner.RewardStat(
6970
default_reward=1, moving_average_reward=2)
7071
},
7172
rewards=[1.2],
7273
policy_rewards=[18],
73-
keys=['default'],
74+
keys=[key],
7475
model_id=model_id)
7576

7677

@@ -111,9 +112,16 @@ def test_generate_trace(self, mock_get_runner):
111112
output_path=os.path.join(tmp_dir.full_path, 'output'),
112113
output_performance_path=os.path.join(tmp_dir.full_path,
113114
'output_performance'),
114-
):
115+
keys_file=os.path.join(tmp_dir.full_path, 'keys_file')):
115116
generate_default_trace.generate_trace()
116117

118+
with open(
119+
os.path.join(tmp_dir.full_path, 'keys_file'),
120+
encoding='utf-8') as keys_file:
121+
keys = [key_line.strip() for key_line in keys_file.readlines()]
122+
for key in keys:
123+
self.assertStartsWith(key, 'key_')
124+
117125

118126
if __name__ == '__main__':
119127
multiprocessing.handle_main(absltest.main)

0 commit comments

Comments
 (0)