diff --git a/compiler_opt/rl/compilation_runner.py b/compiler_opt/rl/compilation_runner.py index b0db6dc5..aa835de1 100644 --- a/compiler_opt/rl/compilation_runner.py +++ b/compiler_opt/rl/compilation_runner.py @@ -298,7 +298,7 @@ def collect_data( ValueError if example under default policy and ml policy does not match. """ if reward_stat is None: - default_result = self._compile_fn( + default_result = self.compile_fn( module_spec, tf_policy_path='', reward_only=bool(tf_policy_path), @@ -308,7 +308,7 @@ def collect_data( } if tf_policy_path: - policy_result = self._compile_fn( + policy_result = self.compile_fn( module_spec, tf_policy_path, reward_only=False, @@ -346,7 +346,7 @@ def collect_data( rewards=rewards, keys=keys) - def _compile_fn( + def compile_fn( self, module_spec: corpus.ModuleSpec, tf_policy_path: str, reward_only: bool, cancellation_manager: Optional[WorkerCancellationManager] diff --git a/compiler_opt/rl/compilation_runner_test.py b/compiler_opt/rl/compilation_runner_test.py index d2c930b3..28b7178f 100644 --- a/compiler_opt/rl/compilation_runner_test.py +++ b/compiler_opt/rl/compilation_runner_test.py @@ -102,7 +102,7 @@ def assertListProtoEqual(self, a, b): self.assertProtoEquals(x, y) @mock.patch(constant.BASE_MODULE_DIR + - '.compilation_runner.CompilationRunner._compile_fn') + '.compilation_runner.CompilationRunner.compile_fn') def test_policy(self, mock_compile_fn): mock_compile_fn.side_effect = _mock_compile_fn runner = compilation_runner.CompilationRunner( @@ -132,7 +132,7 @@ def test_policy(self, mock_compile_fn): self.assertAllClose([0.1998002], data.rewards) @mock.patch(constant.BASE_MODULE_DIR + - '.compilation_runner.CompilationRunner._compile_fn') + '.compilation_runner.CompilationRunner.compile_fn') def test_default(self, mock_compile_fn): mock_compile_fn.side_effect = _mock_compile_fn runner = compilation_runner.CompilationRunner( @@ -163,7 +163,7 @@ def test_default(self, mock_compile_fn): self.assertAllClose([0], data.rewards) @mock.patch(constant.BASE_MODULE_DIR + - '.compilation_runner.CompilationRunner._compile_fn') + '.compilation_runner.CompilationRunner.compile_fn') def test_given_default_size(self, mock_compile_fn): mock_compile_fn.side_effect = _mock_compile_fn runner = compilation_runner.CompilationRunner( @@ -198,7 +198,7 @@ def test_given_default_size(self, mock_compile_fn): self.assertAllClose([0.199800], data.rewards) @mock.patch(constant.BASE_MODULE_DIR + - '.compilation_runner.CompilationRunner._compile_fn') + '.compilation_runner.CompilationRunner.compile_fn') def test_exception_handling(self, mock_compile_fn): mock_compile_fn.side_effect = subprocess.CalledProcessError( returncode=1, cmd='error') diff --git a/compiler_opt/rl/inlining/inlining_runner.py b/compiler_opt/rl/inlining/inlining_runner.py index 6117bbc3..69730f0a 100644 --- a/compiler_opt/rl/inlining/inlining_runner.py +++ b/compiler_opt/rl/inlining/inlining_runner.py @@ -45,7 +45,7 @@ def __init__(self, llvm_size_path: str, *args, **kwargs): super().__init__(*args, **kwargs) self._llvm_size_path = llvm_size_path - def _compile_fn( + def compile_fn( self, module_spec: corpus.ModuleSpec, tf_policy_path: str, reward_only: bool, cancellation_manager: Optional[ compilation_runner.WorkerCancellationManager] diff --git a/compiler_opt/rl/regalloc/regalloc_runner.py b/compiler_opt/rl/regalloc/regalloc_runner.py index 4f58bbfe..ef6cf0ce 100644 --- a/compiler_opt/rl/regalloc/regalloc_runner.py +++ b/compiler_opt/rl/regalloc/regalloc_runner.py @@ -43,7 +43,7 @@ class RegAllocRunner(compilation_runner.CompilationRunner): # TODO: refactor file_paths parameter to ensure correctness during # construction - def _compile_fn( + def compile_fn( self, module_spec: corpus.ModuleSpec, tf_policy_path: str, reward_only: bool, cancellation_manager: Optional[ compilation_runner.WorkerCancellationManager]