From 468adc30851d67d317bac25b880c1fa0f26dd4bb Mon Sep 17 00:00:00 2001 From: Jin Xin Ng Date: Thu, 11 Aug 2022 13:53:30 -0700 Subject: [PATCH 1/3] Add raw_reward_only to runners - Will be useful for validation runner and generate default trace - Allows None values in sequence_examples --- compiler_opt/rl/compilation_runner.py | 73 +++++++++++++++++----- compiler_opt/rl/compilation_runner_test.py | 23 +++++++ 2 files changed, 80 insertions(+), 16 deletions(-) diff --git a/compiler_opt/rl/compilation_runner.py b/compiler_opt/rl/compilation_runner.py index b0db6dc5..67e01b97 100644 --- a/compiler_opt/rl/compilation_runner.py +++ b/compiler_opt/rl/compilation_runner.py @@ -219,6 +219,7 @@ def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]): ] object.__setattr__(self, 'length', sum(lengths)) + # TODO: is it necessary to return keys AND reward_stats(which has the keys)? assert (len(self.serialized_sequence_examples) == len(self.rewards) == (len(self.keys))) assert set(self.keys) == set(self.reward_stats.keys()) @@ -228,6 +229,14 @@ def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]): class CompilationRunnerStub(metaclass=abc.ABCMeta): """The interface of a stub to CompilationRunner, for type checkers.""" + @abc.abstractmethod + def collect_results(self, + module_spec: corpus.ModuleSpec, + tf_policy_path: str, + collect_default_result: bool, + reward_only: bool = False) -> Tuple[Dict, Dict]: + raise NotImplementedError() + @abc.abstractmethod def collect_data( self, module_spec: corpus.ModuleSpec, tf_policy_path: str, @@ -275,6 +284,47 @@ def enable(self): def cancel_all_work(self): self._cancellation_manager.kill_all_processes() + @staticmethod + def get_rewards(result: Dict) -> List[float]: + if len(result) == 0: + return [] + return [v[1] for v in result.values()] + + def collect_results(self, + module_spec: corpus.ModuleSpec, + tf_policy_path: str, + collect_default_result: bool, + reward_only: bool = False) -> Tuple[Dict, Dict]: + """Collect data for the given IR file and policy. + + Args: + module_spec: a ModuleSpec. + tf_policy_path: path to the tensorflow policy. + collect_default_result: whether to get the default result as well. + reward_only: whether to only collect the rewards in the results. + + Returns: + A tuple of the default result and policy result. + """ + default_result = None + policy_result = None + if collect_default_result: + default_result = self._compile_fn( + module_spec, + tf_policy_path='', + reward_only=bool(tf_policy_path) or reward_only, + cancellation_manager=self._cancellation_manager) + policy_result = default_result + + if tf_policy_path: + policy_result = self._compile_fn( + module_spec, + tf_policy_path, + reward_only=reward_only, + cancellation_manager=self._cancellation_manager) + + return default_result, policy_result + def collect_data( self, module_spec: corpus.ModuleSpec, tf_policy_path: str, reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult: @@ -284,8 +334,6 @@ def collect_data( module_spec: a ModuleSpec. tf_policy_path: path to the tensorflow policy. reward_stat: reward stat of this module, None if unknown. - cancellation_token: a CancellationToken through which workers may be - signaled early termination Returns: A CompilationResult. In particular: @@ -297,25 +345,18 @@ def collect_data( compilation_runner.ProcessKilledException is passed through. ValueError if example under default policy and ml policy does not match. """ + default_result, policy_result = self.collect_results( + module_spec, + tf_policy_path, + collect_default_result=reward_stat is None, + reward_only=False) if reward_stat is None: - default_result = self._compile_fn( - module_spec, - tf_policy_path='', - reward_only=bool(tf_policy_path), - cancellation_manager=self._cancellation_manager) + # TODO: Add structure to default_result and policy_result. + # get_rewards above should be updated/removed when this is resolved. reward_stat = { k: RewardStat(v[1], v[1]) for (k, v) in default_result.items() } - if tf_policy_path: - policy_result = self._compile_fn( - module_spec, - tf_policy_path, - reward_only=False, - cancellation_manager=self._cancellation_manager) - else: - policy_result = default_result - sequence_example_list = [] rewards = [] keys = [] diff --git a/compiler_opt/rl/compilation_runner_test.py b/compiler_opt/rl/compilation_runner_test.py index d2c930b3..0f986ca1 100644 --- a/compiler_opt/rl/compilation_runner_test.py +++ b/compiler_opt/rl/compilation_runner_test.py @@ -162,6 +162,29 @@ def test_default(self, mock_compile_fn): }, data.reward_stats) self.assertAllClose([0], data.rewards) + @mock.patch(constant.BASE_MODULE_DIR + + '.compilation_runner.CompilationRunner._compile_fn') + def test_reward_only(self, mock_compile_fn): + mock_compile_fn.side_effect = _mock_compile_fn + runner = compilation_runner.CompilationRunner( + moving_average_decay_rate=_MOVING_AVERAGE_DECAY_RATE) + default_result, policy_result = runner.collect_results( + module_spec=corpus.ModuleSpec(name='dummy'), + tf_policy_path='policy_path', + collect_default_result=True, + reward_only=True) + self.assertEqual(2, mock_compile_fn.call_count) + + self.assertIsNotNone(default_result) + self.assertIsNotNone(policy_result) + + self.assertEqual( + [_DEFAULT_REWARD], + compilation_runner.CompilationRunner.get_rewards(default_result)) + self.assertEqual( + [_POLICY_REWARD], + compilation_runner.CompilationRunner.get_rewards(policy_result)) + @mock.patch(constant.BASE_MODULE_DIR + '.compilation_runner.CompilationRunner._compile_fn') def test_given_default_size(self, mock_compile_fn): From 113c71ea02b3874d86f39bc881685acf73195210 Mon Sep 17 00:00:00 2001 From: Jin Xin Ng Date: Thu, 11 Aug 2022 16:34:25 -0700 Subject: [PATCH 2/3] fix pytype --- compiler_opt/rl/compilation_runner.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/compiler_opt/rl/compilation_runner.py b/compiler_opt/rl/compilation_runner.py index 67e01b97..bc9e2b87 100644 --- a/compiler_opt/rl/compilation_runner.py +++ b/compiler_opt/rl/compilation_runner.py @@ -230,11 +230,12 @@ class CompilationRunnerStub(metaclass=abc.ABCMeta): """The interface of a stub to CompilationRunner, for type checkers.""" @abc.abstractmethod - def collect_results(self, - module_spec: corpus.ModuleSpec, - tf_policy_path: str, - collect_default_result: bool, - reward_only: bool = False) -> Tuple[Dict, Dict]: + def collect_results( + self, + module_spec: corpus.ModuleSpec, + tf_policy_path: str, + collect_default_result: bool, + reward_only: bool = False) -> Tuple[Optional[Dict], Optional[Dict]]: raise NotImplementedError() @abc.abstractmethod @@ -290,11 +291,12 @@ def get_rewards(result: Dict) -> List[float]: return [] return [v[1] for v in result.values()] - def collect_results(self, - module_spec: corpus.ModuleSpec, - tf_policy_path: str, - collect_default_result: bool, - reward_only: bool = False) -> Tuple[Dict, Dict]: + def collect_results( + self, + module_spec: corpus.ModuleSpec, + tf_policy_path: str, + collect_default_result: bool, + reward_only: bool = False) -> Tuple[Optional[Dict], Optional[Dict]]: """Collect data for the given IR file and policy. Args: From 102d27fe6ff5e18ce65fa325fbfcb225b13d6912 Mon Sep 17 00:00:00 2001 From: Jin Xin Ng Date: Fri, 12 Aug 2022 10:26:00 -0700 Subject: [PATCH 3/3] remove redundant lines --- compiler_opt/rl/compilation_runner.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/compiler_opt/rl/compilation_runner.py b/compiler_opt/rl/compilation_runner.py index bc9e2b87..1496086d 100644 --- a/compiler_opt/rl/compilation_runner.py +++ b/compiler_opt/rl/compilation_runner.py @@ -287,8 +287,6 @@ def cancel_all_work(self): @staticmethod def get_rewards(result: Dict) -> List[float]: - if len(result) == 0: - return [] return [v[1] for v in result.values()] def collect_results(