|
14 | 14 | # limitations under the License.
|
15 | 15 | """Tests for policy_utils."""
|
16 | 16 |
|
| 17 | +import os |
| 18 | + |
17 | 19 | from absl.testing import absltest
|
18 | 20 | import numpy as np
|
19 |
| -import os |
20 | 21 | import tensorflow as tf
|
21 | 22 | from tf_agents.networks import actor_distribution_network
|
22 |
| -from tf_agents.policies import actor_policy, tf_policy |
| 23 | +from tf_agents.policies import actor_policy |
| 24 | +from tf_agents.policies import tf_policy |
23 | 25 |
|
24 | 26 | from compiler_opt.es import policy_utils
|
25 |
| -from compiler_opt.rl import policy_saver, registry |
| 27 | +from compiler_opt.rl import inlining |
| 28 | +from compiler_opt.rl import policy_saver |
| 29 | +from compiler_opt.rl import regalloc |
| 30 | +from compiler_opt.rl import registry |
26 | 31 | from compiler_opt.rl.inlining import config as inlining_config
|
27 |
| -from compiler_opt.rl.inlining import InliningConfig |
28 | 32 | from compiler_opt.rl.regalloc import config as regalloc_config
|
29 |
| -from compiler_opt.rl.regalloc import RegallocEvictionConfig, regalloc_network |
| 33 | +from compiler_opt.rl.regalloc import regalloc_network |
30 | 34 |
|
31 | 35 |
|
32 | 36 | class ConfigTest(absltest.TestCase):
|
33 | 37 |
|
34 | 38 | # TODO(abenalaast): Issue #280
|
35 | 39 | def test_inlining_config(self):
|
36 |
| - problem_config = registry.get_configuration(implementation=InliningConfig) |
| 40 | + problem_config = registry.get_configuration( |
| 41 | + implementation=inlining.InliningConfig) |
37 | 42 | time_step_spec, action_spec = problem_config.get_signature_spec()
|
38 | 43 | quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab')
|
39 | 44 | creator = inlining_config.get_observation_processing_layer_creator(
|
@@ -64,7 +69,7 @@ def test_inlining_config(self):
|
64 | 69 | # TODO(abenalaast): Issue #280
|
65 | 70 | def test_regalloc_config(self):
|
66 | 71 | problem_config = registry.get_configuration(
|
67 |
| - implementation=RegallocEvictionConfig) |
| 72 | + implementation=regalloc.RegallocEvictionConfig) |
68 | 73 | time_step_spec, action_spec = problem_config.get_signature_spec()
|
69 | 74 | quantile_file_dir = os.path.join('compiler_opt', 'rl', 'regalloc', 'vocab')
|
70 | 75 | creator = regalloc_config.get_observation_processing_layer_creator(
|
@@ -105,7 +110,8 @@ class VectorTest(absltest.TestCase):
|
105 | 110 | # TODO(abenalaast): Issue #280
|
106 | 111 | def test_set_vectorized_parameters_for_policy(self):
|
107 | 112 | # create a policy
|
108 |
| - problem_config = registry.get_configuration(implementation=InliningConfig) |
| 113 | + problem_config = registry.get_configuration( |
| 114 | + implementation=inlining.InliningConfig) |
109 | 115 | time_step_spec, action_spec = problem_config.get_signature_spec()
|
110 | 116 | quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab')
|
111 | 117 | creator = inlining_config.get_observation_processing_layer_creator(
|
@@ -167,7 +173,8 @@ def test_set_vectorized_parameters_for_policy(self):
|
167 | 173 | # TODO(abenalaast): Issue #280
|
168 | 174 | def test_get_vectorized_parameters_from_policy(self):
|
169 | 175 | # create a policy
|
170 |
| - problem_config = registry.get_configuration(implementation=InliningConfig) |
| 176 | + problem_config = registry.get_configuration( |
| 177 | + implementation=inlining.InliningConfig) |
171 | 178 | time_step_spec, action_spec = problem_config.get_signature_spec()
|
172 | 179 | quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab')
|
173 | 180 | creator = inlining_config.get_observation_processing_layer_creator(
|
@@ -214,7 +221,8 @@ def test_get_vectorized_parameters_from_policy(self):
|
214 | 221 | # TODO(abenalaast): Issue #280
|
215 | 222 | def test_tfpolicy_and_loaded_policy_produce_same_variable_order(self):
|
216 | 223 | # create a policy
|
217 |
| - problem_config = registry.get_configuration(implementation=InliningConfig) |
| 224 | + problem_config = registry.get_configuration( |
| 225 | + implementation=inlining.InliningConfig) |
218 | 226 | time_step_spec, action_spec = problem_config.get_signature_spec()
|
219 | 227 | quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab')
|
220 | 228 | creator = inlining_config.get_observation_processing_layer_creator(
|
|
0 commit comments