diff --git a/compiler_opt/rl/compilation_runner.py b/compiler_opt/rl/compilation_runner.py
index b0db6dc5..1496086d 100644
--- a/compiler_opt/rl/compilation_runner.py
+++ b/compiler_opt/rl/compilation_runner.py
@@ -219,6 +219,7 @@ def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
     ]
     object.__setattr__(self, 'length', sum(lengths))
+    # TODO: is it necessary to return keys AND reward_stats(which has the keys)?
     assert (len(self.serialized_sequence_examples) == len(self.rewards) ==
             (len(self.keys)))
     assert set(self.keys) == set(self.reward_stats.keys())
 
@@ -228,6 +229,15 @@ def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
 class CompilationRunnerStub(metaclass=abc.ABCMeta):
   """The interface of a stub to CompilationRunner, for type checkers."""
 
+  @abc.abstractmethod
+  def collect_results(
+      self,
+      module_spec: corpus.ModuleSpec,
+      tf_policy_path: str,
+      collect_default_result: bool,
+      reward_only: bool = False) -> Tuple[Optional[Dict], Optional[Dict]]:
+    raise NotImplementedError()
+
   @abc.abstractmethod
   def collect_data(
       self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
@@ -275,6 +285,46 @@ def enable(self):
   def cancel_all_work(self):
     self._cancellation_manager.kill_all_processes()
 
+  @staticmethod
+  def get_rewards(result: Dict) -> List[float]:
+    return [v[1] for v in result.values()]
+
+  def collect_results(
+      self,
+      module_spec: corpus.ModuleSpec,
+      tf_policy_path: str,
+      collect_default_result: bool,
+      reward_only: bool = False) -> Tuple[Optional[Dict], Optional[Dict]]:
+    """Collect data for the given IR file and policy.
+
+    Args:
+      module_spec: a ModuleSpec.
+      tf_policy_path: path to the tensorflow policy.
+      collect_default_result: whether to get the default result as well.
+      reward_only: whether to only collect the rewards in the results.
+
+    Returns:
+      A tuple of the default result and policy result.
+    """
+    default_result = None
+    policy_result = None
+    if collect_default_result:
+      default_result = self._compile_fn(
+          module_spec,
+          tf_policy_path='',
+          reward_only=bool(tf_policy_path) or reward_only,
+          cancellation_manager=self._cancellation_manager)
+      policy_result = default_result
+
+    if tf_policy_path:
+      policy_result = self._compile_fn(
+          module_spec,
+          tf_policy_path,
+          reward_only=reward_only,
+          cancellation_manager=self._cancellation_manager)
+
+    return default_result, policy_result
+
   def collect_data(
       self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
       reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult:
@@ -284,8 +334,6 @@ def collect_data(
       module_spec: a ModuleSpec.
       tf_policy_path: path to the tensorflow policy.
       reward_stat: reward stat of this module, None if unknown.
-      cancellation_token: a CancellationToken through which workers may be
-        signaled early termination
 
     Returns:
       A CompilationResult. In particular:
@@ -297,25 +345,18 @@ def collect_data(
       compilation_runner.ProcessKilledException is passed through.
       ValueError if example under default policy and ml policy does not match.
     """
+    default_result, policy_result = self.collect_results(
+        module_spec,
+        tf_policy_path,
+        collect_default_result=reward_stat is None,
+        reward_only=False)
     if reward_stat is None:
-      default_result = self._compile_fn(
-          module_spec,
-          tf_policy_path='',
-          reward_only=bool(tf_policy_path),
-          cancellation_manager=self._cancellation_manager)
+      # TODO: Add structure to default_result and policy_result.
+      # get_rewards above should be updated/removed when this is resolved.
       reward_stat = {
           k: RewardStat(v[1], v[1]) for (k, v) in default_result.items()
       }
 
-    if tf_policy_path:
-      policy_result = self._compile_fn(
-          module_spec,
-          tf_policy_path,
-          reward_only=False,
-          cancellation_manager=self._cancellation_manager)
-    else:
-      policy_result = default_result
-
     sequence_example_list = []
     rewards = []
     keys = []
diff --git a/compiler_opt/rl/compilation_runner_test.py b/compiler_opt/rl/compilation_runner_test.py
index d2c930b3..0f986ca1 100644
--- a/compiler_opt/rl/compilation_runner_test.py
+++ b/compiler_opt/rl/compilation_runner_test.py
@@ -162,6 +162,29 @@ def test_default(self, mock_compile_fn):
         }, data.reward_stats)
     self.assertAllClose([0], data.rewards)
 
+  @mock.patch(constant.BASE_MODULE_DIR +
+              '.compilation_runner.CompilationRunner._compile_fn')
+  def test_reward_only(self, mock_compile_fn):
+    mock_compile_fn.side_effect = _mock_compile_fn
+    runner = compilation_runner.CompilationRunner(
+        moving_average_decay_rate=_MOVING_AVERAGE_DECAY_RATE)
+    default_result, policy_result = runner.collect_results(
+        module_spec=corpus.ModuleSpec(name='dummy'),
+        tf_policy_path='policy_path',
+        collect_default_result=True,
+        reward_only=True)
+    self.assertEqual(2, mock_compile_fn.call_count)
+
+    self.assertIsNotNone(default_result)
+    self.assertIsNotNone(policy_result)
+
+    self.assertEqual(
+        [_DEFAULT_REWARD],
+        compilation_runner.CompilationRunner.get_rewards(default_result))
+    self.assertEqual(
+        [_POLICY_REWARD],
+        compilation_runner.CompilationRunner.get_rewards(policy_result))
+
   @mock.patch(constant.BASE_MODULE_DIR +
               '.compilation_runner.CompilationRunner._compile_fn')
   def test_given_default_size(self, mock_compile_fn):
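For context, a minimal usage sketch of the new collect_results/get_rewards pair, modeled on test_reward_only above. It assumes a runner whose _compile_fn is actually usable (the unit test patches it on the base class); the module name and policy path below are hypothetical placeholders, not values taken from this change:

    from compiler_opt.rl import compilation_runner
    from compiler_opt.rl import corpus

    # Assumes _compile_fn is implemented/mocked for this runner; the base
    # class only provides the collection logic around it.
    runner = compilation_runner.CompilationRunner(
        moving_average_decay_rate=0.9)

    # Compile the module once under the default heuristic and once under the
    # policy, keeping only rewards (no serialized sequence examples).
    default_result, policy_result = runner.collect_results(
        module_spec=corpus.ModuleSpec(name='some_module'),  # hypothetical name
        tf_policy_path='/path/to/saved_policy',  # hypothetical path
        collect_default_result=True,
        reward_only=True)

    # Each result maps keys to entries whose second element is the reward.
    default_rewards = compilation_runner.CompilationRunner.get_rewards(
        default_result)
    policy_rewards = compilation_runner.CompilationRunner.get_rewards(
        policy_result)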