Commit 4cae2c6

import modules instead of importing packages (#76)
1 parent 6efe174 commit 4cae2c6
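
The change follows the convention, recommended for example by the Google Python style guide, of importing modules rather than the classes defined in them, so every class reference stays qualified by its module. Below is a minimal sketch of the pattern; the tf_agents import paths are the ones used in this commit, while the helper functions are hypothetical and exist only to illustrate the module-qualified names.

# Before this change, classes were imported directly:
#   from tf_agents.agents import TFAgent
#   from tf_agents.policies import TFPolicy
# After the change, the modules are imported instead and the classes are
# referenced through them, keeping their origin visible at each use site.
from typing import Dict, List

from tf_agents.agents import tf_agent
from tf_agents.policies import tf_policy


def make_agent() -> tf_agent.TFAgent:
  """Hypothetical factory, shown only to illustrate module-qualified types."""
  raise NotImplementedError('illustrative sketch only')


def policy_names(policies: Dict[str, tf_policy.TFPolicy]) -> List[str]:
  """Hypothetical helper typed against the module-qualified TFPolicy."""
  return sorted(policies)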

File tree

  compiler_opt/rl/agent_creators.py
  compiler_opt/rl/policy_saver.py
  compiler_opt/rl/train_bc.py
  compiler_opt/rl/train_locally.py
  compiler_opt/rl/trainer.py

5 files changed: +27 -29 lines changed

compiler_opt/rl/agent_creators.py

Lines changed: 8 additions & 8 deletions
@@ -19,7 +19,7 @@
 import gin
 import tensorflow as tf
 
-from tf_agents.agents import TFAgent
+from tf_agents.agents import tf_agent
 from tf_agents.agents.behavioral_cloning import behavioral_cloning_agent
 from tf_agents.agents.dqn import dqn_agent
 from tf_agents.agents.ppo import ppo_agent
@@ -29,10 +29,10 @@
 from compiler_opt.rl import constant_value_network
 
 
-def _create_behavioral_cloning_agent(time_step_spec: types.NestedTensorSpec,
-                                     action_spec: types.NestedTensorSpec,
-                                     preprocessing_layers: types.NestedLayer,
-                                     policy_network: types.Network) -> TFAgent:
+def _create_behavioral_cloning_agent(
+    time_step_spec: types.NestedTensorSpec, action_spec: types.NestedTensorSpec,
+    preprocessing_layers: types.NestedLayer,
+    policy_network: types.Network) -> tf_agent.TFAgent:
   """Creates a behavioral_cloning_agent."""
 
   network = policy_network(
@@ -48,7 +48,7 @@ def _create_behavioral_cloning_agent(time_step_spec: types.NestedTensorSpec,
 def _create_dqn_agent(time_step_spec: types.NestedTensorSpec,
                       action_spec: types.NestedTensorSpec,
                       preprocessing_layers: types.NestedLayer,
-                      policy_network: types.Network) -> TFAgent:
+                      policy_network: types.Network) -> tf_agent.TFAgent:
   """Creates a dqn_agent."""
   network = policy_network(
       time_step_spec.observation,
@@ -62,7 +62,7 @@ def _create_dqn_agent(time_step_spec: types.NestedTensorSpec,
 def _create_ppo_agent(time_step_spec: types.NestedTensorSpec,
                       action_spec: types.NestedTensorSpec,
                       preprocessing_layers: types.NestedLayer,
-                      policy_network: types.Network) -> TFAgent:
+                      policy_network: types.Network) -> tf_agent.TFAgent:
   """Creates a ppo_agent."""
 
   actor_network = policy_network(
@@ -87,7 +87,7 @@ def create_agent(agent_name: constant.AgentName,
                  action_spec: types.NestedTensorSpec,
                  preprocessing_layer_creator: Callable[[types.TensorSpec],
                                                        tf.keras.layers.Layer],
-                 policy_network: types.Network) -> TFAgent:
+                 policy_network: types.Network) -> tf_agent.TFAgent:
   """Creates a tfa.agents.TFAgent object.
 
   Args:

compiler_opt/rl/policy_saver.py

Lines changed: 3 additions & 3 deletions
@@ -18,7 +18,7 @@
 import os
 
 import tensorflow as tf
-from tf_agents.policies import TFPolicy
+from tf_agents.policies import tf_policy
 from tf_agents.policies import policy_saver
 
 from typing import Dict, Tuple
@@ -77,14 +77,14 @@ class PolicySaver(object):
   ```
   """
 
-  def __init__(self, policy_dict: Dict[str, TFPolicy]):
+  def __init__(self, policy_dict: Dict[str, tf_policy.TFPolicy]):
     """Initialize the PolicySaver object.
 
     Args:
       policy_dict: A dict mapping from policy name to policy.
     """
     self._policy_saver_dict: Dict[str, Tuple[
-        policy_saver.PolicySaver, TFPolicy]] = {
+        policy_saver.PolicySaver, tf_policy.TFPolicy]] = {
            policy_name: (policy_saver.PolicySaver(
                policy, batch_size=1, use_nest_path_signatures=False), policy)
            for policy_name, policy in policy_dict.items()

compiler_opt/rl/train_bc.py

Lines changed: 8 additions & 9 deletions
@@ -30,8 +30,8 @@
 from compiler_opt.rl import registry
 from compiler_opt.rl import trainer
 
-from tf_agents.agents import TFAgent
-from tf_agents.policies import TFPolicy
+from tf_agents.agents import tf_agent
+from tf_agents.policies import tf_policy
 
 from typing import Dict
 
@@ -63,13 +63,12 @@ def train_eval(agent_name=constant.AgentName.BEHAVIORAL_CLONE,
   preprocessing_layer_creator = problem_config.get_preprocessing_layer_creator()
 
   # Initialize trainer and policy saver.
-  tf_agent: TFAgent = agent_creators.create_agent(agent_name, time_step_spec,
-                                                  action_spec,
-                                                  preprocessing_layer_creator)
-  llvm_trainer = trainer.Trainer(root_dir=root_dir, agent=tf_agent)
-  policy_dict: Dict[str, TFPolicy] = {
-      'saved_policy': tf_agent.policy,
-      'saved_collect_policy': tf_agent.collect_policy,
+  agent: tf_agent.TFAgent = agent_creators.create_agent(
+      agent_name, time_step_spec, action_spec, preprocessing_layer_creator)
+  llvm_trainer = trainer.Trainer(root_dir=root_dir, agent=agent)
+  policy_dict: Dict[str, tf_policy.TFPolicy] = {
+      'saved_policy': agent.policy,
+      'saved_collect_policy': agent.collect_policy,
   }
   saver = policy_saver.PolicySaver(policy_dict=policy_dict)

compiler_opt/rl/train_locally.py

Lines changed: 6 additions & 7 deletions
@@ -25,7 +25,7 @@
 from absl import logging
 import gin
 import tensorflow as tf
-from tf_agents.agents import TFAgent
+from tf_agents.agents import tf_agent
 from tf_agents.system import system_multiprocessing as multiprocessing
 from typing import List
 
@@ -78,9 +78,8 @@ def train_eval(agent_name=constant.AgentName.PPO,
   preprocessing_layer_creator = problem_config.get_preprocessing_layer_creator()
 
   # Initialize trainer and policy saver.
-  tf_agent: TFAgent = agent_creators.create_agent(agent_name, time_step_spec,
-                                                  action_spec,
-                                                  preprocessing_layer_creator)
+  agent: tf_agent.TFAgent = agent_creators.create_agent(
+      agent_name, time_step_spec, action_spec, preprocessing_layer_creator)
   # create the random network distillation object
   random_network_distillation = None
   if use_random_network_distillation:
@@ -91,13 +90,13 @@
 
   llvm_trainer = trainer.Trainer(
       root_dir=root_dir,
-      agent=tf_agent,
+      agent=agent,
       random_network_distillation=random_network_distillation,
       warmstart_policy_dir=warmstart_policy_dir)
 
   policy_dict = {
-      'saved_policy': tf_agent.policy,
-      'saved_collect_policy': tf_agent.collect_policy,
+      'saved_policy': agent.policy,
+      'saved_collect_policy': agent.collect_policy,
   }
   saver = policy_saver.PolicySaver(policy_dict=policy_dict)

compiler_opt/rl/trainer.py

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@
 import gin
 import tensorflow as tf
 from compiler_opt.rl import random_net_distillation
-from tf_agents.agents import TFAgent
+from tf_agents.agents import tf_agent
 from tf_agents.policies import policy_loader
 
 from tf_agents.utils import common as common_utils
@@ -47,7 +47,7 @@ class Trainer(object):
   def __init__(
       self,
       root_dir: str,
-      agent: TFAgent,
+      agent: tf_agent.TFAgent,
       random_network_distillation: Optional[
           random_net_distillation.RandomNetworkDistillation] = None,
       warmstart_policy_dir: Optional[str] = None,
