Skip to content

Commit aca7c4e

Browse files
authored
Fix type annotation and reformat imports in policy_utils (#282)
1 parent 79d7049 commit aca7c4e

File tree

2 files changed

+30
-16
lines changed

2 files changed

+30
-16
lines changed

compiler_opt/es/policy_utils.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,19 @@
1414
# limitations under the License.
1515
"""Util functions to create and edit a tf_agent policy."""
1616

17+
from typing import Protocol, Sequence, Type
18+
1719
import gin
1820
import numpy as np
1921
import numpy.typing as npt
2022
import tensorflow as tf
21-
from typing import Protocol, Sequence
22-
23-
from compiler_opt.rl import policy_saver, registry
2423
from tf_agents.networks import network
25-
from tf_agents.policies import actor_policy, greedy_policy, tf_policy
24+
from tf_agents.policies import actor_policy
25+
from tf_agents.policies import greedy_policy
26+
from tf_agents.policies import tf_policy
27+
28+
from compiler_opt.rl import policy_saver
29+
from compiler_opt.rl import registry
2630

2731

2832
class HasModelVariables(Protocol):
@@ -31,8 +35,10 @@ class HasModelVariables(Protocol):
3135

3236
# TODO(abenalaast): Issue #280
3337
@gin.configurable(module='policy_utils')
34-
def create_actor_policy(actor_network_ctor: network.DistributionNetwork,
35-
greedy: bool = False) -> tf_policy.TFPolicy:
38+
def create_actor_policy(
39+
actor_network_ctor: Type[network.DistributionNetwork],
40+
greedy: bool = False,
41+
) -> tf_policy.TFPolicy:
3642
"""Creates an actor policy."""
3743
problem_config = registry.get_configuration()
3844
time_step_spec, action_spec = problem_config.get_signature_spec()

compiler_opt/es/policy_utils_test.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,31 @@
1414
# limitations under the License.
1515
"""Tests for policy_utils."""
1616

17+
import os
18+
1719
from absl.testing import absltest
1820
import numpy as np
19-
import os
2021
import tensorflow as tf
2122
from tf_agents.networks import actor_distribution_network
22-
from tf_agents.policies import actor_policy, tf_policy
23+
from tf_agents.policies import actor_policy
24+
from tf_agents.policies import tf_policy
2325

2426
from compiler_opt.es import policy_utils
25-
from compiler_opt.rl import policy_saver, registry
27+
from compiler_opt.rl import inlining
28+
from compiler_opt.rl import policy_saver
29+
from compiler_opt.rl import regalloc
30+
from compiler_opt.rl import registry
2631
from compiler_opt.rl.inlining import config as inlining_config
27-
from compiler_opt.rl.inlining import InliningConfig
2832
from compiler_opt.rl.regalloc import config as regalloc_config
29-
from compiler_opt.rl.regalloc import RegallocEvictionConfig, regalloc_network
33+
from compiler_opt.rl.regalloc import regalloc_network
3034

3135

3236
class ConfigTest(absltest.TestCase):
3337

3438
# TODO(abenalaast): Issue #280
3539
def test_inlining_config(self):
36-
problem_config = registry.get_configuration(implementation=InliningConfig)
40+
problem_config = registry.get_configuration(
41+
implementation=inlining.InliningConfig)
3742
time_step_spec, action_spec = problem_config.get_signature_spec()
3843
quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab')
3944
creator = inlining_config.get_observation_processing_layer_creator(
@@ -64,7 +69,7 @@ def test_inlining_config(self):
6469
# TODO(abenalaast): Issue #280
6570
def test_regalloc_config(self):
6671
problem_config = registry.get_configuration(
67-
implementation=RegallocEvictionConfig)
72+
implementation=regalloc.RegallocEvictionConfig)
6873
time_step_spec, action_spec = problem_config.get_signature_spec()
6974
quantile_file_dir = os.path.join('compiler_opt', 'rl', 'regalloc', 'vocab')
7075
creator = regalloc_config.get_observation_processing_layer_creator(
@@ -105,7 +110,8 @@ class VectorTest(absltest.TestCase):
105110
# TODO(abenalaast): Issue #280
106111
def test_set_vectorized_parameters_for_policy(self):
107112
# create a policy
108-
problem_config = registry.get_configuration(implementation=InliningConfig)
113+
problem_config = registry.get_configuration(
114+
implementation=inlining.InliningConfig)
109115
time_step_spec, action_spec = problem_config.get_signature_spec()
110116
quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab')
111117
creator = inlining_config.get_observation_processing_layer_creator(
@@ -167,7 +173,8 @@ def test_set_vectorized_parameters_for_policy(self):
167173
# TODO(abenalaast): Issue #280
168174
def test_get_vectorized_parameters_from_policy(self):
169175
# create a policy
170-
problem_config = registry.get_configuration(implementation=InliningConfig)
176+
problem_config = registry.get_configuration(
177+
implementation=inlining.InliningConfig)
171178
time_step_spec, action_spec = problem_config.get_signature_spec()
172179
quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab')
173180
creator = inlining_config.get_observation_processing_layer_creator(
@@ -214,7 +221,8 @@ def test_get_vectorized_parameters_from_policy(self):
214221
# TODO(abenalaast): Issue #280
215222
def test_tfpolicy_and_loaded_policy_produce_same_variable_order(self):
216223
# create a policy
217-
problem_config = registry.get_configuration(implementation=InliningConfig)
224+
problem_config = registry.get_configuration(
225+
implementation=inlining.InliningConfig)
218226
time_step_spec, action_spec = problem_config.get_signature_spec()
219227
quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab')
220228
creator = inlining_config.get_observation_processing_layer_creator(

0 commit comments

Comments
 (0)