Commit 8b0d885

add the support for dumping best trajectories during training (#153)
* add the option to dump best trajectories during training
* track -> dump
* fix test
* typo
1 parent 12c239b · commit 8b0d885

6 files changed · +60 -17 lines changed

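These diffs wire an existing helper, compiler_opt.rl.best_trajectory.BestTrajectoryRepo, into data collection and training; the helper itself is not part of this commit. For orientation, here is a minimal sketch of the interface the call sites below assume -- the field layout, the "lower reward wins" comparison, and the JSON handling are assumptions, not the actual implementation in compiler_opt/rl/best_trajectory.py.

# Sketch only: BestTrajectoryRepo as used by this commit. Field names, the
# comparison direction, and (de)serialization details are assumptions.
import dataclasses
from typing import Dict


@dataclasses.dataclass
class BestTrajectory:
  reward: float
  sequence_example: bytes  # serialized tf.train.SequenceExample


class BestTrajectoryRepo:

  def __init__(self, action_name: str):
    self._action_name = action_name
    # module name -> identifier (e.g. function name) -> best trajectory so far.
    self._best: Dict[str, Dict[str, BestTrajectory]] = {}

  def update_if_better_trajectory(self, module_name: str, identifier: str,
                                  reward: float, sequence_example: bytes):
    # Assumption: a lower reward (e.g. smaller native size) is better.
    per_module = self._best.setdefault(module_name, {})
    current = per_module.get(identifier)
    if current is None or reward < current.reward:
      per_module[identifier] = BestTrajectory(reward, sequence_example)

  def sink_to_json_file(self, path: str):
    # Persist the repo to `path`; serialization details omitted in this sketch.
    raise NotImplementedError

  def load_from_json_file(self, path: str):
    # Restore the repo from `path`; deserialization details omitted.
    raise NotImplementedError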

compiler_opt/rl/compilation_runner.py
12 additions, 6 deletions

@@ -25,12 +25,13 @@
 
 from absl import flags
 from absl import logging
+import tensorflow as tf
+
 from compiler_opt.distributed.worker import Worker
 from compiler_opt.distributed.worker import WorkerFuture
 from compiler_opt.rl import constant
-from compiler_opt.rl import policy_saver
 from compiler_opt.rl import corpus
-import tensorflow as tf
+from compiler_opt.rl import policy_saver
 
 _COMPILATION_TIMEOUT = flags.DEFINE_integer(
     'compilation_timeout', 60,

@@ -219,11 +220,12 @@ class CompilationResult:
     length: total length of all sequence examples, derived from sequence_examples.
     reward_stats: a dictionary from keys (e.g. function names) to a RewardStat.
     rewards: a list of reward values.
+    policy_rewards: a list of reward values under policy.
     keys: a list of keys.
 
   The object must observe the following invariants:
-    1) The entries of sequence_examples, rewards, and keys correspond to eachoter
-    at the same index.
+    1) The entries of sequence_examples, rewards, policy_rewards and keys
+    correspond to each other at the same index.
 
     2) The keys in reward stats are those in the keys field.
   """

@@ -232,6 +234,7 @@ class CompilationResult:
   length: int = dataclasses.field(init=False)
   reward_stats: Dict[str, RewardStat]
   rewards: List[float]
+  policy_rewards: List[float]
   keys: List[str]
 
   def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):

@@ -243,8 +246,8 @@ def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
     ]
     object.__setattr__(self, 'length', sum(lengths))
 
-    assert (len(self.serialized_sequence_examples) == len(self.rewards) ==
-            (len(self.keys)))
+    assert (len(self.serialized_sequence_examples) == len(self.rewards) == len(
+        self.policy_rewards) == len(self.keys))
     assert set(self.keys) == set(self.reward_stats.keys())
     assert not hasattr(self, 'sequence_examples')
 

@@ -356,6 +359,7 @@ def collect_data(
 
     sequence_example_list = []
     rewards = []
+    policy_rewards = []
    keys = []
    for k, v in policy_result.items():
      sequence_example = v[0]

@@ -376,12 +380,14 @@ def collect_data(
          policy_reward * (1 - self._moving_average_decay_rate))
      rewards.append(
          _calculate_reward(policy=policy_reward, baseline=default_reward))
+      policy_rewards.append(policy_reward)
      keys.append(k)
 
    return CompilationResult(
        sequence_examples=sequence_example_list,
        reward_stats=reward_stat,
        rewards=rewards,
+        policy_rewards=policy_rewards,
        keys=keys)
 
  def compile_fn(
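To make the invariant concrete, here is a small, hypothetical construction (the feature name and all numbers are invented) that satisfies the lengthened assert: sequence_examples, rewards, policy_rewards and keys all have one entry per key.

# Hypothetical example of a well-formed CompilationResult after this change.
import tensorflow as tf
from compiler_opt.rl import compilation_runner


def _make_example() -> tf.train.SequenceExample:
  ex = tf.train.SequenceExample()
  # One feature list with two timesteps; the feature name is a placeholder.
  features = ex.feature_lists.feature_list['inlining_decision'].feature
  features.add().int64_list.value.append(1)
  features.add().int64_list.value.append(0)
  return ex


result = compilation_runner.CompilationResult(
    sequence_examples=[_make_example()],
    reward_stats={
        'main':
            compilation_runner.RewardStat(
                default_reward=10.0, moving_average_reward=9.0)
    },
    rewards=[0.2],          # reward relative to the baseline
    policy_rewards=[8.0],   # raw reward observed under the policy
    keys=['main'])
assert result.policy_rewards == [8.0]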

compiler_opt/rl/compilation_runner_test.py
3 additions, 0 deletions

@@ -134,6 +134,7 @@ def test_policy(self, mock_compile_fn):
                          (1 - _MOVING_AVERAGE_DECAY_RATE))
         }, data.reward_stats)
     self.assertAllClose([0.1998002], data.rewards)
+    self.assertAllClose([8], data.policy_rewards)
 
   @mock.patch(constant.BASE_MODULE_DIR +
               '.compilation_runner.CompilationRunner.compile_fn')

@@ -162,6 +163,7 @@ def test_default(self, mock_compile_fn):
                 moving_average_reward=_DEFAULT_REWARD)
         }, data.reward_stats)
     self.assertAllClose([0], data.rewards)
+    self.assertAllClose([10], data.policy_rewards)
 
   @mock.patch(constant.BASE_MODULE_DIR +
               '.compilation_runner.CompilationRunner.compile_fn')

@@ -197,6 +199,7 @@ def test_given_default_size(self, mock_compile_fn):
                 _POLICY_REWARD * (1 - _MOVING_AVERAGE_DECAY_RATE))
         }, data.reward_stats)
     self.assertAllClose([0.199800], data.rewards)
+    self.assertAllClose([8], data.policy_rewards)
 
   @mock.patch(constant.BASE_MODULE_DIR +
               '.compilation_runner.CompilationRunner.compile_fn')
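The new assertions simply surface the raw reward the mocked compile_fn reports. Assuming the unchanged test constants are _POLICY_REWARD = 8 and _DEFAULT_REWARD = 10 (consistent with the expected values above), data.policy_rewards carries the raw 8 under the policy and 10 under the default policy, while data.rewards keeps the baseline-relative value, roughly (10 - 8) / 10 ≈ 0.2; the asserted 0.1998002 reflects a small constant presumably added to the denominator.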

compiler_opt/rl/local_data_collector.py
14 additions, 2 deletions

@@ -17,13 +17,14 @@
 import concurrent.futures
 import itertools
 import time
-from typing import Callable, Dict, Iterator, List, Tuple, Optional
+from typing import Callable, Dict, Iterator, List, Optional, Tuple
 
 from absl import logging
 from tf_agents.trajectories import trajectory
 
 from compiler_opt.distributed import worker
 from compiler_opt.distributed import buffered_scheduler
+from compiler_opt.rl import best_trajectory
 from compiler_opt.rl import compilation_runner
 from compiler_opt.rl import corpus
 from compiler_opt.rl import data_collector

@@ -41,6 +42,7 @@ def __init__(
       parser: Callable[[List[str]], Iterator[trajectory.Trajectory]],
       reward_stat_map: Dict[str, Optional[Dict[str,
                                                compilation_runner.RewardStat]]],
+      best_trajectory_repo: Optional[best_trajectory.BestTrajectoryRepo],
       exit_checker_ctor=data_collector.EarlyExitChecker):
     # TODO(mtrofin): type exit_checker_ctor when we get typing.Protocol support
     super().__init__()

@@ -53,6 +55,7 @@ def __init__(
                              compilation_runner
                              .CompilationRunnerStub] = self._worker_pool.get_currently_active()
     self._reward_stat_map = reward_stat_map
+    self._best_trajectory_repo = best_trajectory_repo
     self._exit_checker_ctor = exit_checker_ctor
     # _reset_workers is a future that resolves when post-data collection cleanup
     # work completes, i.e. cancelling all work and re-enabling the workers.

@@ -126,7 +129,7 @@ def collect_data(
     """Collect data for a given policy.
 
     Args:
-      policy_path: the path to the policy directory to collect data with.
+      policy: a policy_saver.Policy object to collect data with.
 
     Returns:
       An iterator of batched trajectory.Trajectory that are ready to be fed to

@@ -183,6 +186,15 @@ def wrapup():
     self._reward_stat_map.update(
         {spec.name: res.reward_stats for (spec, res) in successful_work})
 
+    if self._best_trajectory_repo is not None:
+      for spec, res in successful_work:
+        module_name = spec.name
+        for (identifier, reward,
+             sequence_example) in zip(res.keys, res.policy_rewards,
+                                      res.serialized_sequence_examples):
+          self._best_trajectory_repo.update_if_better_trajectory(
+              module_name, identifier, reward, sequence_example)
+
     monitor_dict = {}
     monitor_dict['default'] = {
         'success_modules': len(successful_work),
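Note that the repo is fed res.policy_rewards -- the raw reward observed under the policy -- rather than the baseline-relative res.rewards. A toy walk-through, under the same assumption as the sketch above that a lower reward wins (all names and numbers hypothetical):

# Toy walk-through, assuming lower reward is better as in the sketch above.
repo = BestTrajectoryRepo(action_name='inlining_decision')
repo.update_if_better_trajectory('mod.o', 'foo', 10.0, b'trajectory_a')
repo.update_if_better_trajectory('mod.o', 'foo', 8.0, b'trajectory_b')   # replaces trajectory_a
repo.update_if_better_trajectory('mod.o', 'foo', 9.0, b'trajectory_c')   # ignored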

compiler_opt/rl/local_data_collector_test.py
13 additions, 7 deletions

@@ -16,24 +16,22 @@
 
 # pylint: disable=protected-access
 import collections
-
 import string
 import sys
+from typing import List, Tuple
 
 import tensorflow as tf
 from tf_agents.system import system_multiprocessing as multiprocessing
 
+# This is https://github.com/google/pytype/issues/764
+from google.protobuf import text_format  # pytype: disable=pyi-error
 from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPoolManager
 from compiler_opt.rl import compilation_runner
 from compiler_opt.rl import corpus
 from compiler_opt.rl import data_collector
 from compiler_opt.rl import local_data_collector
 from compiler_opt.rl import policy_saver
 
-# This is https://github.com/google/pytype/issues/764
-from google.protobuf import text_format  # pytype: disable=pyi-error
-from typing import List, Tuple
-
 _policy_str = 'policy'.encode(encoding='utf-8')
 
 _mock_policy = policy_saver.Policy(output_spec=bytes(), policy=_policy_str)

@@ -71,6 +69,7 @@ def mock_collect_data(loaded_module_spec: corpus.LoadedModuleSpec, policy,
                 default_reward=1, moving_average_reward=2)
         },
         rewards=[1.2],
+        policy_rewards=[36],
         keys=['default'])
   else:
     return compilation_runner.CompilationResult(

@@ -81,6 +80,7 @@ def mock_collect_data(loaded_module_spec: corpus.LoadedModuleSpec, policy,
                 default_reward=1, moving_average_reward=3)
         },
         rewards=[3.4],
+        policy_rewards=[18],
         keys=['default'])
 
 

@@ -93,7 +93,11 @@ def collect_data(self, loaded_module_spec, policy, reward_stat):
         self._cancellation_manager)
 
     return compilation_runner.CompilationResult(
-        sequence_examples=[], reward_stats={}, rewards=[], keys=[])
+        sequence_examples=[],
+        reward_stats={},
+        rewards=[],
+        policy_rewards=[],
+        keys=[])
 
 
 class MyRunner(compilation_runner.CompilationRunner):

@@ -154,7 +158,8 @@ def _test_iterator_fn(data_list):
         num_modules=9,
         worker_pool=lwp,
         parser=create_test_iterator_fn(),
-        reward_stat_map=collections.defaultdict(lambda: None))
+        reward_stat_map=collections.defaultdict(lambda: None),
+        best_trajectory_repo=None)
 
     # reset the sampler, so the next time we collect, we collect the same
     # modules. We do it before the collect_data call, because that's when

@@ -226,6 +231,7 @@ def wait(self, _):
         worker_pool=lwp,
         parser=parser,
         reward_stat_map=collections.defaultdict(lambda: None),
+        best_trajectory_repo=None,
         exit_checker_ctor=QuickExiter)
     collector.collect_data(policy=_mock_policy)
     collector._join_pending_jobs()

compiler_opt/rl/train_locally.py
17 additions, 2 deletions

@@ -19,6 +19,7 @@
 import json
 import os
 import time
+from typing import List
 
 from absl import app
 from absl import flags

@@ -27,10 +28,10 @@
 import tensorflow as tf
 from tf_agents.agents import tf_agent
 from tf_agents.system import system_multiprocessing as multiprocessing
-from typing import List
 
 from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPoolManager
 from compiler_opt.rl import agent_creators
+from compiler_opt.rl import best_trajectory
 from compiler_opt.rl import compilation_runner
 from compiler_opt.rl import constant
 from compiler_opt.rl import corpus

@@ -69,6 +70,7 @@ def train_eval(worker_manager_class=LocalWorkerPoolManager,
                train_sequence_length=1,
                deploy_policy_name='saved_policy',
                use_random_network_distillation=False,
+               dump_best_trajectory=False,
                moving_average_decay_rate=1):
   """Train for LLVM inliner."""
   root_dir = FLAGS.root_dir

@@ -134,6 +136,15 @@ def sequence_example_iterator_fn(seq_ex: List[str]):
   logging.info('Loaded Reward Stat Map from disk, containing %d modules',
                len(reward_stat_map))
 
+  best_trajectory_repo = None
+  best_trajecroty_repo_path = os.path.join(root_dir,
+                                           'best_trajectory_repo.json')
+  if dump_best_trajectory:
+    best_trajectory_repo = best_trajectory.BestTrajectoryRepo(
+        action_name=action_spec.name)
+    if tf.io.gfile.exists(best_trajecroty_repo_path):
+      best_trajectory_repo.load_from_json_file(best_trajecroty_repo_path)
+
   with worker_manager_class(
       worker_class=problem_config.get_runner_type(),
       count=FLAGS.num_workers,

@@ -143,7 +154,8 @@ def sequence_example_iterator_fn(seq_ex: List[str]):
         num_modules=num_modules,
         worker_pool=worker_pool,
         parser=sequence_example_iterator_fn,
-        reward_stat_map=reward_stat_map)
+        reward_stat_map=reward_stat_map,
+        best_trajectory_repo=best_trajectory_repo)
 
   # Repeat for num_policy_iterations iterations.
   t1 = time.time()

@@ -155,6 +167,9 @@ def sequence_example_iterator_fn(seq_ex: List[str]):
      with tf.io.gfile.GFile(reward_stat_map_path, 'w') as f:
        json.dump(reward_stat_map, f, cls=constant.DataClassJSONEncoder)
 
+      if best_trajectory_repo is not None:
+        best_trajectory_repo.sink_to_json_file(best_trajecroty_repo_path)
+
      policy_path = os.path.join(root_dir, 'policy',
                                 str(llvm_trainer.global_step_numpy()))
      saver.save(policy_path)
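The new train_eval parameter defaults to False, so the feature is opt-in. Assuming train_eval's keyword arguments are bound through gin the way its other parameters are, enabling the dump amounts to a single override such as train_eval.dump_best_trajectory = True in the gin config; best_trajectory_repo.json is then written under --root_dir after each policy iteration and reloaded from that path if training restarts.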

compiler_opt/tools/generate_default_trace_test.py
1 addition, 0 deletions

@@ -58,6 +58,7 @@ def collect_data(self, loaded_module_spec, policy, reward_stat):
                 default_reward=1, moving_average_reward=2)
         },
         rewards=[1.2],
+        policy_rewards=[18],
         keys=['default'])
 
 