Skip to content

Commit bbe20bc

Browse files
authored
Add a model_id parameter to collect_data. (#178)
The model_id parameter allows joining compilation results with the model used during the compilation. This is useful for debugging errors in training as well as for ensuring experience freshness in the distributed training setting.
1 parent aef0a0b commit bbe20bc

File tree

7 files changed: 43 additions and 27 deletions.

compiler_opt/rl/compilation_runner.py

Lines changed: 15 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -237,6 +237,9 @@ class CompilationResult:
237237
policy_rewards: List[float]
238238
keys: List[str]
239239

240+
# The id of the model used to generate this compilation result
241+
model_id: Optional[int]
242+
240243
def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
241244
object.__setattr__(self, 'serialized_sequence_examples',
242245
[x.SerializeToString() for x in sequence_examples])
@@ -260,8 +263,8 @@ def collect_data(
260263
self,
261264
loaded_module_spec: corpus.LoadedModuleSpec,
262265
policy: Optional[policy_saver.Policy] = None,
263-
reward_stat: Optional[Dict[str, RewardStat]] = None
264-
) -> WorkerFuture[CompilationResult]:
266+
reward_stat: Optional[Dict[str, RewardStat]] = None,
267+
model_id: Optional[int] = None) -> WorkerFuture[CompilationResult]:
265268
raise NotImplementedError()
266269

267270
@abc.abstractmethod
@@ -315,17 +318,18 @@ def pause_all_work(self):
315318
def resume_all_work(self):
316319
self._cancellation_manager.resume_all_processes()
317320

318-
def collect_data(
319-
self,
320-
loaded_module_spec: corpus.LoadedModuleSpec,
321-
policy: Optional[policy_saver.Policy] = None,
322-
reward_stat: Optional[Dict[str, RewardStat]] = None) -> CompilationResult:
321+
def collect_data(self,
322+
loaded_module_spec: corpus.LoadedModuleSpec,
323+
policy: Optional[policy_saver.Policy] = None,
324+
reward_stat: Optional[Dict[str, RewardStat]] = None,
325+
model_id: Optional[int] = None) -> CompilationResult:
323326
"""Collect data for the given IR file and policy.
324327
325328
Args:
326329
loaded_module_spec: a LoadedModuleSpec.
327330
policy: serialized policy.
328331
reward_stat: reward stat of this module, None if unknown.
332+
model_id: id for the model used to collect data.
329333
330334
Returns:
331335
A CompilationResult. In particular:
@@ -341,7 +345,8 @@ def collect_data(
341345
final_cmd_line = loaded_module_spec.build_command_line(tempdir)
342346
tf_policy_path = ''
343347
if policy is not None:
344-
tf_policy_path = os.path.join(tempdir, 'policy')
348+
model_id_suffix = f'-{model_id}' if model_id is not None else ''
349+
tf_policy_path = os.path.join(tempdir, 'policy' + model_id_suffix)
345350
policy.to_filesystem(tf_policy_path)
346351

347352
if reward_stat is None:
@@ -388,7 +393,8 @@ def collect_data(
388393
reward_stats=reward_stat,
389394
rewards=rewards,
390395
policy_rewards=policy_rewards,
391-
keys=keys)
396+
keys=keys,
397+
model_id=model_id)
392398

393399
def compile_fn(
394400
self, command_line: corpus.FullyQualifiedCmdLine, tf_policy_path: str,

compiler_opt/rl/data_collector.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -52,12 +52,13 @@ class DataCollector(metaclass=abc.ABCMeta):
5252

5353
@abc.abstractmethod
5454
def collect_data(
55-
self, policy: policy_saver.Policy
55+
self, policy: policy_saver.Policy, model_id: int
5656
) -> Tuple[Iterator[trajectory.Trajectory], Dict[str, Dict[str, float]]]:
5757
"""Collect data for a given policy.
5858
5959
Args:
6060
policy_path: the path to the policy directory to collect data with.
61+
model_id: the id of the model used to collect data.
6162
6263
Returns:
6364
An iterator of batched trajectory.Trajectory that are ready to be fed to
@@ -126,7 +127,7 @@ def wait(self, get_num_finished_work):
126127
127128
Args:
128129
get_num_finished_work: a callable object which returns the amount of
129-
finished work.
130+
finished work.
130131
131132
Returns:
132133
The amount of time waited.

compiler_opt/rl/local_data_collector.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -101,7 +101,7 @@ def _join_pending_jobs(self):
101101
time.time() - t1)
102102

103103
def _schedule_jobs(
104-
self, policy: policy_saver.Policy,
104+
self, policy: policy_saver.Policy, model_id: int,
105105
sampled_modules: List[corpus.LoadedModuleSpec]
106106
) -> List[worker.WorkerFuture[compilation_runner.CompilationResult]]:
107107
# by now, all the pending work, which was signaled to cancel, must've
@@ -114,7 +114,7 @@ def _schedule_jobs(
114114
def work_factory(job):
115115

116116
def work(w: compilation_runner.CompilationRunnerStub):
117-
return w.collect_data(*job)
117+
return w.collect_data(*job, model_id=model_id)
118118

119119
return work
120120

@@ -124,7 +124,7 @@ def work(w: compilation_runner.CompilationRunnerStub):
124124
work, self._workers, self._worker_pool.get_worker_concurrency())
125125

126126
def collect_data(
127-
self, policy: policy_saver.Policy
127+
self, policy: policy_saver.Policy, model_id: int
128128
) -> Tuple[Iterator[trajectory.Trajectory], Dict[str, Dict[str, float]]]:
129129
"""Collect data for a given policy.
130130
@@ -145,7 +145,8 @@ def collect_data(
145145
logging.info('resolving prefetched sample took: %d seconds',
146146
time.time() - time1)
147147
self._next_sample = self._prefetch_next_sample()
148-
self._current_futures = self._schedule_jobs(policy, sampled_modules)
148+
self._current_futures = self._schedule_jobs(policy, model_id,
149+
sampled_modules)
149150

150151
def wait_for_termination():
151152
early_exit = self._exit_checker_ctor(num_modules=self._num_modules)

compiler_opt/rl/local_data_collector_test.py

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ def _get_sequence_example(feature_value):
5252

5353

5454
def mock_collect_data(loaded_module_spec: corpus.LoadedModuleSpec, policy,
55-
reward_stat):
55+
reward_stat, model_id):
5656
assert loaded_module_spec.name.startswith('dummy')
5757
assert policy.policy == _policy_str
5858
assert reward_stat is None or reward_stat == {
@@ -70,7 +70,8 @@ def mock_collect_data(loaded_module_spec: corpus.LoadedModuleSpec, policy,
7070
},
7171
rewards=[1.2],
7272
policy_rewards=[36],
73-
keys=['default'])
73+
keys=['default'],
74+
model_id=model_id)
7475
else:
7576
return compilation_runner.CompilationResult(
7677
sequence_examples=[_get_sequence_example(feature_value=2)],
@@ -81,13 +82,14 @@ def mock_collect_data(loaded_module_spec: corpus.LoadedModuleSpec, policy,
8182
},
8283
rewards=[3.4],
8384
policy_rewards=[18],
84-
keys=['default'])
85+
keys=['default'],
86+
model_id=model_id)
8587

8688

8789
class Sleeper(compilation_runner.CompilationRunner):
8890
"""Test CompilationRunner that just sleeps."""
8991

90-
def collect_data(self, loaded_module_spec, policy, reward_stat):
92+
def collect_data(self, loaded_module_spec, policy, reward_stat, model_id):
9193
_ = loaded_module_spec, policy, reward_stat
9294
compilation_runner.start_cancellable_process(['sleep', '3600s'], 3600,
9395
self._cancellation_manager)
@@ -97,7 +99,8 @@ def collect_data(self, loaded_module_spec, policy, reward_stat):
9799
reward_stats={},
98100
rewards=[],
99101
policy_rewards=[],
100-
keys=[])
102+
keys=[],
103+
model_id=model_id)
101104

102105

103106
class MyRunner(compilation_runner.CompilationRunner):
@@ -166,7 +169,8 @@ def _test_iterator_fn(data_list):
166169
# we'll re-sample to prefetch the next batch.
167170
sampler.reset()
168171

169-
data_iterator, monitor_dict = collector.collect_data(policy=_mock_policy)
172+
data_iterator, monitor_dict = collector.collect_data(
173+
policy=_mock_policy, model_id=0)
170174
data = list(data_iterator)
171175
self.assertEqual([1, 2, 3], data)
172176
expected_monitor_dict_subset = {
@@ -184,7 +188,8 @@ def _test_iterator_fn(data_list):
184188
**monitor_dict,
185189
**expected_monitor_dict_subset
186190
})
187-
data_iterator, monitor_dict = collector.collect_data(policy=_mock_policy)
191+
data_iterator, monitor_dict = collector.collect_data(
192+
policy=_mock_policy, model_id=0)
188193
data = list(data_iterator)
189194
# because we reset the sampler, these are the same modules
190195
self.assertEqual([4, 5, 6], data)
@@ -233,7 +238,7 @@ def wait(self, _):
233238
reward_stat_map=collections.defaultdict(lambda: None),
234239
best_trajectory_repo=None,
235240
exit_checker_ctor=QuickExiter)
236-
collector.collect_data(policy=_mock_policy)
241+
collector.collect_data(policy=_mock_policy, model_id=0)
237242
collector._join_pending_jobs()
238243
killed = 0
239244
for w in collector._current_futures:

compiler_opt/rl/train_locally.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -176,7 +176,8 @@ def sequence_example_iterator_fn(seq_ex: List[str]):
176176

177177
dataset_iter, monitor_dict = data_collector.collect_data(
178178
policy=policy_saver.Policy.from_filesystem(
179-
os.path.join(policy_path, deploy_policy_name)))
179+
os.path.join(policy_path, deploy_policy_name)),
180+
model_id=llvm_trainer.global_step_numpy())
180181
llvm_trainer.train(dataset_iter, monitor_dict, num_iterations)
181182

182183
data_collector.on_dataset_consumed(dataset_iter)

compiler_opt/tools/generate_default_trace.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,8 @@ def worker(policy_path: Optional[str],
108108
data = runner.collect_data(
109109
loaded_module_spec=loaded_module_spec,
110110
policy=policy,
111-
reward_stat=None)
111+
reward_stat=None,
112+
model_id=0)
112113
if not m:
113114
results_queue.put(
114115
(loaded_module_spec.name, data.serialized_sequence_examples,

compiler_opt/tools/generate_default_trace_test.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@
3636
class MockCompilationRunner(compilation_runner.CompilationRunner):
3737
"""A compilation runner just for test."""
3838

39-
def collect_data(self, loaded_module_spec, policy, reward_stat):
39+
def collect_data(self, loaded_module_spec, policy, reward_stat, model_id):
4040
sequence_example_text = """
4141
feature_lists {
4242
feature_list {
@@ -59,7 +59,8 @@ def collect_data(self, loaded_module_spec, policy, reward_stat):
5959
},
6060
rewards=[1.2],
6161
policy_rewards=[18],
62-
keys=['default'])
62+
keys=['default'],
63+
model_id=model_id)
6364

6465

6566
class GenerateDefaultTraceTest(absltest.TestCase):

0 commit comments

Comments
 (0)