Skip to content

Commit 096d1fb

Browse files
committed
Add raw_reward_only to runners
- Will be useful for validation runner and generate default trace - Allows None values in sequence_examples
1 parent 99671e4 commit 096d1fb

File tree

4 files changed

+60
-15
lines changed

4 files changed

+60
-15
lines changed

compiler_opt/rl/compilation_runner.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -211,16 +211,20 @@ class CompilationResult:
211211
keys: List[str]
212212

213213
def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
214-
object.__setattr__(self, 'serialized_sequence_examples',
215-
[x.SerializeToString() for x in sequence_examples])
214+
object.__setattr__(
215+
self, 'serialized_sequence_examples',
216+
[x.SerializeToString() for x in sequence_examples if x is not None])
216217
lengths = [
217218
len(next(iter(x.feature_lists.feature_list.values())).feature)
218219
for x in sequence_examples
220+
if x is not None
219221
]
220222
object.__setattr__(self, 'length', sum(lengths))
221223

222-
assert (len(self.serialized_sequence_examples) == len(self.rewards) ==
223-
(len(self.keys)))
224+
# TODO: is it necessary to return keys AND reward_stats(which has the keys)?
225+
# sequence_examples' length could also just not be checked, this allows
226+
# raw_reward_only to do less work
227+
assert (len(sequence_examples) == len(self.rewards) == (len(self.keys)))
224228
assert set(self.keys) == set(self.reward_stats.keys())
225229
assert not hasattr(self, 'sequence_examples')
226230

@@ -230,9 +234,11 @@ class CompilationRunnerStub(metaclass=abc.ABCMeta):
230234

231235
@abc.abstractmethod
232236
def collect_data(
233-
self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
234-
reward_stat: Optional[Dict[str, RewardStat]]
235-
) -> WorkerFuture[CompilationResult]:
237+
self,
238+
module_spec: corpus.ModuleSpec,
239+
tf_policy_path: str,
240+
reward_stat: Optional[Dict[str, RewardStat]],
241+
raw_reward_only: bool = False) -> WorkerFuture[CompilationResult]:
236242
raise NotImplementedError()
237243

238244
@abc.abstractmethod
@@ -275,17 +281,18 @@ def enable(self):
275281
def cancel_all_work(self):
276282
self._cancellation_manager.kill_all_processes()
277283

278-
def collect_data(
279-
self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
280-
reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult:
284+
def collect_data(self,
285+
module_spec: corpus.ModuleSpec,
286+
tf_policy_path: str,
287+
reward_stat: Optional[Dict[str, RewardStat]],
288+
raw_reward_only=False) -> CompilationResult:
281289
"""Collect data for the given IR file and policy.
282290
283291
Args:
284292
module_spec: a ModuleSpec.
285293
tf_policy_path: path to the tensorflow policy.
286294
reward_stat: reward stat of this module, None if unknown.
287-
cancellation_token: a CancellationToken through which workers may be
288-
signaled early termination
295+
raw_reward_only: whether to return the raw reward value without examples.
289296
290297
Returns:
291298
A CompilationResult. In particular:
@@ -311,7 +318,7 @@ def collect_data(
311318
policy_result = self._compile_fn(
312319
module_spec,
313320
tf_policy_path,
314-
reward_only=False,
321+
reward_only=raw_reward_only,
315322
cancellation_manager=self._cancellation_manager)
316323
else:
317324
policy_result = default_result
@@ -326,6 +333,11 @@ def collect_data(
326333
raise ValueError(
327334
(f'Example {k} does not exist under default policy for '
328335
f'module {module_spec.name}'))
336+
if raw_reward_only:
337+
sequence_example_list.append(None)
338+
rewards.append(policy_reward)
339+
keys.append(k)
340+
continue
329341
default_reward = reward_stat[k].default_reward
330342
moving_average_reward = reward_stat[k].moving_average_reward
331343
sequence_example = _overwrite_trajectory_reward(

compiler_opt/rl/compilation_runner_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,31 @@ def test_default(self, mock_compile_fn):
162162
}, data.reward_stats)
163163
self.assertAllClose([0], data.rewards)
164164

165+
@mock.patch(constant.BASE_MODULE_DIR +
166+
'.compilation_runner.CompilationRunner._compile_fn')
167+
def test_raw_reward_only(self, mock_compile_fn):
168+
mock_compile_fn.side_effect = _mock_compile_fn
169+
runner = compilation_runner.CompilationRunner(
170+
moving_average_decay_rate=_MOVING_AVERAGE_DECAY_RATE)
171+
data = runner.collect_data(
172+
module_spec=corpus.ModuleSpec(name='dummy'),
173+
tf_policy_path='policy_path',
174+
reward_stat=None,
175+
raw_reward_only=True)
176+
self.assertEqual(2, mock_compile_fn.call_count)
177+
178+
self.assertLen(data.serialized_sequence_examples, 0)
179+
180+
self.assertEqual(0, data.length)
181+
self.assertCountEqual(
182+
{
183+
'default':
184+
compilation_runner.RewardStat(
185+
default_reward=_DEFAULT_REWARD,
186+
moving_average_reward=_DEFAULT_REWARD)
187+
}, data.reward_stats)
188+
self.assertAllClose([_POLICY_REWARD], data.rewards)
189+
165190
@mock.patch(constant.BASE_MODULE_DIR +
166191
'.compilation_runner.CompilationRunner._compile_fn')
167192
def test_given_default_size(self, mock_compile_fn):

compiler_opt/rl/local_data_collector_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,11 @@ def mock_collect_data(module_spec, tf_policy_dir, reward_stat):
8080
class Sleeper(compilation_runner.CompilationRunner):
8181
"""Test CompilationRunner that just sleeps."""
8282

83-
def collect_data(self, module_spec, tf_policy_path, reward_stat):
83+
def collect_data(self,
84+
module_spec,
85+
tf_policy_path,
86+
reward_stat,
87+
raw_reward_only=False):
8488
_ = module_spec, tf_policy_path, reward_stat
8589
compilation_runner.start_cancellable_process(['sleep', '3600s'], 3600,
8690
self._cancellation_manager)

compiler_opt/tools/generate_default_trace_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@
3434
class MockCompilationRunner(compilation_runner.CompilationRunner):
3535
"""A compilation runner just for test."""
3636

37-
def collect_data(self, module_spec, tf_policy_path, reward_stat):
37+
def collect_data(self,
38+
module_spec,
39+
tf_policy_path,
40+
reward_stat,
41+
raw_reward_only=False):
3842
sequence_example_text = """
3943
feature_lists {
4044
feature_list {

0 commit comments

Comments
 (0)