Skip to content

Commit 468adc3

Browse files
committed
Add raw_reward_only to runners
- Will be useful for validation runner and generate default trace - Allows None values in sequence_examples
1 parent 99671e4 commit 468adc3

File tree

2 files changed

+80
-16
lines changed

2 files changed

+80
-16
lines changed

compiler_opt/rl/compilation_runner.py

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
219219
]
220220
object.__setattr__(self, 'length', sum(lengths))
221221

222+
# TODO: is it necessary to return keys AND reward_stats(which has the keys)?
222223
assert (len(self.serialized_sequence_examples) == len(self.rewards) ==
223224
(len(self.keys)))
224225
assert set(self.keys) == set(self.reward_stats.keys())
@@ -228,6 +229,14 @@ def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
228229
class CompilationRunnerStub(metaclass=abc.ABCMeta):
229230
"""The interface of a stub to CompilationRunner, for type checkers."""
230231

232+
@abc.abstractmethod
233+
def collect_results(self,
234+
module_spec: corpus.ModuleSpec,
235+
tf_policy_path: str,
236+
collect_default_result: bool,
237+
reward_only: bool = False) -> Tuple[Dict, Dict]:
238+
raise NotImplementedError()
239+
231240
@abc.abstractmethod
232241
def collect_data(
233242
self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
@@ -275,6 +284,47 @@ def enable(self):
275284
def cancel_all_work(self):
276285
self._cancellation_manager.kill_all_processes()
277286

287+
@staticmethod
288+
def get_rewards(result: Dict) -> List[float]:
289+
if len(result) == 0:
290+
return []
291+
return [v[1] for v in result.values()]
292+
293+
def collect_results(self,
294+
module_spec: corpus.ModuleSpec,
295+
tf_policy_path: str,
296+
collect_default_result: bool,
297+
reward_only: bool = False) -> Tuple[Dict, Dict]:
298+
"""Collect data for the given IR file and policy.
299+
300+
Args:
301+
module_spec: a ModuleSpec.
302+
tf_policy_path: path to the tensorflow policy.
303+
collect_default_result: whether to get the default result as well.
304+
reward_only: whether to only collect the rewards in the results.
305+
306+
Returns:
307+
A tuple of the default result and policy result.
308+
"""
309+
default_result = None
310+
policy_result = None
311+
if collect_default_result:
312+
default_result = self._compile_fn(
313+
module_spec,
314+
tf_policy_path='',
315+
reward_only=bool(tf_policy_path) or reward_only,
316+
cancellation_manager=self._cancellation_manager)
317+
policy_result = default_result
318+
319+
if tf_policy_path:
320+
policy_result = self._compile_fn(
321+
module_spec,
322+
tf_policy_path,
323+
reward_only=reward_only,
324+
cancellation_manager=self._cancellation_manager)
325+
326+
return default_result, policy_result
327+
278328
def collect_data(
279329
self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
280330
reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult:
@@ -284,8 +334,6 @@ def collect_data(
284334
module_spec: a ModuleSpec.
285335
tf_policy_path: path to the tensorflow policy.
286336
reward_stat: reward stat of this module, None if unknown.
287-
cancellation_token: a CancellationToken through which workers may be
288-
signaled early termination
289337
290338
Returns:
291339
A CompilationResult. In particular:
@@ -297,25 +345,18 @@ def collect_data(
297345
compilation_runner.ProcessKilledException is passed through.
298346
ValueError if example under default policy and ml policy does not match.
299347
"""
348+
default_result, policy_result = self.collect_results(
349+
module_spec,
350+
tf_policy_path,
351+
collect_default_result=reward_stat is None,
352+
reward_only=False)
300353
if reward_stat is None:
301-
default_result = self._compile_fn(
302-
module_spec,
303-
tf_policy_path='',
304-
reward_only=bool(tf_policy_path),
305-
cancellation_manager=self._cancellation_manager)
354+
# TODO: Add structure to default_result and policy_result.
355+
# get_rewards above should be updated/removed when this is resolved.
306356
reward_stat = {
307357
k: RewardStat(v[1], v[1]) for (k, v) in default_result.items()
308358
}
309359

310-
if tf_policy_path:
311-
policy_result = self._compile_fn(
312-
module_spec,
313-
tf_policy_path,
314-
reward_only=False,
315-
cancellation_manager=self._cancellation_manager)
316-
else:
317-
policy_result = default_result
318-
319360
sequence_example_list = []
320361
rewards = []
321362
keys = []

compiler_opt/rl/compilation_runner_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,29 @@ def test_default(self, mock_compile_fn):
162162
}, data.reward_stats)
163163
self.assertAllClose([0], data.rewards)
164164

165+
@mock.patch(constant.BASE_MODULE_DIR +
166+
'.compilation_runner.CompilationRunner._compile_fn')
167+
def test_reward_only(self, mock_compile_fn):
168+
mock_compile_fn.side_effect = _mock_compile_fn
169+
runner = compilation_runner.CompilationRunner(
170+
moving_average_decay_rate=_MOVING_AVERAGE_DECAY_RATE)
171+
default_result, policy_result = runner.collect_results(
172+
module_spec=corpus.ModuleSpec(name='dummy'),
173+
tf_policy_path='policy_path',
174+
collect_default_result=True,
175+
reward_only=True)
176+
self.assertEqual(2, mock_compile_fn.call_count)
177+
178+
self.assertIsNotNone(default_result)
179+
self.assertIsNotNone(policy_result)
180+
181+
self.assertEqual(
182+
[_DEFAULT_REWARD],
183+
compilation_runner.CompilationRunner.get_rewards(default_result))
184+
self.assertEqual(
185+
[_POLICY_REWARD],
186+
compilation_runner.CompilationRunner.get_rewards(policy_result))
187+
165188
@mock.patch(constant.BASE_MODULE_DIR +
166189
'.compilation_runner.CompilationRunner._compile_fn')
167190
def test_given_default_size(self, mock_compile_fn):

0 commit comments

Comments
 (0)