Commit 2006fcc

Refactor agent configuration. (#237)
Replace the agent ID and the if-checks scattered through the code (in agent creation and in data processing, respectively) with `AgentConfig`, an abstraction that encapsulates the per-agent behaviors.
1 parent 74d638a commit 2006fcc

21 files changed: +298 −306 lines
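For orientation, here is a minimal before/after sketch of the call-site change this refactor implies. It is only a sketch: `time_step_spec`, `action_spec`, and `_observation_processing_layer` are placeholders for values defined elsewhere in the project, and the real call sites are among the 21 files touched by this commit.

```python
from tf_agents.networks import actor_distribution_network

from compiler_opt.rl import agent_creators

# Before: the caller picked the agent via an enum and create_agent() dispatched
# through an if/elif chain (constant.AgentName is the pre-commit API):
#
#   tf_agent = agent_creators.create_agent(
#       agent_name=constant.AgentName.PPO,
#       time_step_spec=time_step_spec,
#       action_spec=action_spec,
#       preprocessing_layer_creator=_observation_processing_layer,
#       policy_network=actor_distribution_network.ActorDistributionNetwork)

# After: the caller builds an AgentConfig, which owns both agent creation and
# the policy-info parsing hooks used during data processing.
agent_config = agent_creators.PPOAgentConfig(
    time_step_spec=time_step_spec, action_spec=action_spec)
tf_agent = agent_creators.create_agent(
    agent_config,
    preprocessing_layer_creator=_observation_processing_layer,
    policy_network=actor_distribution_network.ActorDistributionNetwork)
```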

compiler_opt/rl/agent_creators.py

Lines changed: 184 additions & 125 deletions
```diff
@@ -14,147 +14,206 @@
 # limitations under the License.
 """util function to create a tf_agent."""
 
-from typing import Callable
+from typing import Any, Callable, Dict
 
+import abc
 import gin
 import tensorflow as tf
 
 from tf_agents.agents import tf_agent
 from tf_agents.agents.behavioral_cloning import behavioral_cloning_agent
 from tf_agents.agents.dqn import dqn_agent
 from tf_agents.agents.ppo import ppo_agent
+from tf_agents.specs import tensor_spec
 from tf_agents.typing import types
 
-from compiler_opt.rl import constant
 from compiler_opt.rl import constant_value_network
 from compiler_opt.rl.distributed import agent as distributed_ppo_agent
 
 
-def _create_behavioral_cloning_agent(
-    time_step_spec: types.NestedTensorSpec, action_spec: types.NestedTensorSpec,
-    preprocessing_layers: types.NestedLayer,
-    policy_network: types.Network) -> tf_agent.TFAgent:
-  """Creates a behavioral_cloning_agent."""
-
-  network = policy_network(
-      time_step_spec.observation,
-      action_spec,
-      preprocessing_layers=preprocessing_layers,
-      name='QNetwork')
-
-  return behavioral_cloning_agent.BehavioralCloningAgent(
-      time_step_spec, action_spec, cloning_network=network, num_outer_dims=2)
-
-
-def _create_dqn_agent(time_step_spec: types.NestedTensorSpec,
-                      action_spec: types.NestedTensorSpec,
-                      preprocessing_layers: types.NestedLayer,
-                      policy_network: types.Network) -> tf_agent.TFAgent:
-  """Creates a dqn_agent."""
-  network = policy_network(
-      time_step_spec.observation,
-      action_spec,
-      preprocessing_layers=preprocessing_layers,
-      name='QNetwork')
-
-  return dqn_agent.DqnAgent(time_step_spec, action_spec, q_network=network)
-
-
-def _create_ppo_agent(time_step_spec: types.NestedTensorSpec,
-                      action_spec: types.NestedTensorSpec,
-                      preprocessing_layers: types.NestedLayer,
-                      policy_network: types.Network) -> tf_agent.TFAgent:
-  """Creates a ppo_agent."""
-
-  actor_network = policy_network(
-      time_step_spec.observation,
-      action_spec,
-      preprocessing_layers=preprocessing_layers,
-      name='ActorDistributionNetwork')
-
-  critic_network = constant_value_network.ConstantValueNetwork(
-      time_step_spec.observation, name='ConstantValueNetwork')
-
-  return ppo_agent.PPOAgent(
-      time_step_spec,
-      action_spec,
-      actor_net=actor_network,
-      value_net=critic_network)
-
-
-def _create_ppo_distributed_agent(
-    time_step_spec: types.NestedTensorSpec, action_spec: types.NestedTensorSpec,
-    preprocessing_layers: types.NestedLayer,
-    policy_network: types.Network) -> tf_agent.TFAgent:
-  """Creates a ppo_distributed agent."""
-  actor_network = policy_network(
-      time_step_spec.observation,
-      action_spec,
-      preprocessing_layers=preprocessing_layers,
-      preprocessing_combiner=tf.keras.layers.Concatenate(),
-      name='ActorDistributionNetwork')
-
-  critic_network = constant_value_network.ConstantValueNetwork(
-      time_step_spec.observation, name='ConstantValueNetwork')
-
-  return distributed_ppo_agent.MLGOPPOAgent(
-      time_step_spec,
-      action_spec,
-      optimizer=tf.keras.optimizers.Adam(learning_rate=4e-4, epsilon=1e-5),
-      actor_net=actor_network,
-      value_net=critic_network,
-      value_pred_loss_coef=0.0,
-      entropy_regularization=0.01,
-      importance_ratio_clipping=0.2,
-      discount_factor=1.0,
-      gradient_clipping=1.0,
-      debug_summaries=False,
-      value_clipping=None,
-      aggregate_losses_across_replicas=True,
-      loss_scaling_factor=1.0)
+class AgentConfig(metaclass=abc.ABCMeta):
+  """Agent creation and data processing hook-ups."""
+
+  def __init__(self, *, time_step_spec: types.NestedTensorSpec,
+               action_spec: types.NestedTensorSpec):
+    self._time_step_spec = time_step_spec
+    self._action_spec = action_spec
+
+  @property
+  def time_step_spec(self):
+    return self._time_step_spec
+
+  @property
+  def action_spec(self):
+    return self._action_spec
+
+  @abc.abstractmethod
+  def create_agent(self, preprocessing_layers: tf.keras.layers.Layer,
+                   policy_network: types.Network) -> tf_agent.TFAgent:
+    """Specific agent configs must implement this."""
+    raise NotImplementedError()
+
+  def get_policy_info_parsing_dict(
+      self) -> Dict[str, tf.io.FixedLenSequenceFeature]:
+    """Return the parsing dict for the policy info."""
+    return {}
+
+  # pylint: disable=unused-argument
+  def process_parsed_sequence_and_get_policy_info(
+      self, parsed_sequence: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
+    """Function to process parsed_sequence and to return policy_info.
+
+    Args:
+      parsed_sequence: A dict from feature_name to feature_value parsed from TF
+        SequenceExample.
+
+    Returns:
+      A nested policy_info for given agent.
+    """
+    return {}
 
 
 @gin.configurable
-def create_agent(agent_name: constant.AgentName,
-                 time_step_spec: types.NestedTensorSpec,
-                 action_spec: types.NestedTensorSpec,
+def create_agent(agent_config: AgentConfig,
                  preprocessing_layer_creator: Callable[[types.TensorSpec],
                                                        tf.keras.layers.Layer],
-                 policy_network: types.Network) -> tf_agent.TFAgent:
-  """Creates a tfa.agents.TFAgent object.
-
-  Args:
-    agent_name: AgentName, enum type of the agent to create.
-    time_step_spec: A `TimeStep` spec of the expected time_steps.
-    action_spec: A nest of BoundedTensorSpec representing the actions.
-    preprocessing_layer_creator: A callable returns feature processing layer
-      given observation_spec.
-    policy_network: A tf_agents.networks.Network class.
-
-  Returns:
-    tf_agent: A tfa.agents.TFAgent object.
-
-  Raises:
-    ValueError: If `agent_name` is not in supported list.
-  """
-  assert policy_network is not None
-  assert agent_name is not None
-
-  preprocessing_layers = tf.nest.map_structure(preprocessing_layer_creator,
-                                               time_step_spec.observation)
-
-  if agent_name == constant.AgentName.BEHAVIORAL_CLONE:
-    return _create_behavioral_cloning_agent(time_step_spec, action_spec,
-                                            preprocessing_layers,
-                                            policy_network)
-  elif agent_name == constant.AgentName.DQN:
-    return _create_dqn_agent(time_step_spec, action_spec, preprocessing_layers,
-                             policy_network)
-  elif agent_name == constant.AgentName.PPO:
-    return _create_ppo_agent(time_step_spec, action_spec, preprocessing_layers,
-                             policy_network)
-  elif agent_name == constant.AgentName.PPO_DISTRIBUTED:
-    return _create_ppo_distributed_agent(time_step_spec, action_spec,
-                                         preprocessing_layers, policy_network)
-  else:
-    raise ValueError(f'Unknown agent: {agent_name}')
+                 policy_network: types.Network):
+  """Gin configurable wrapper of AgentConfig.create_agent.
+  Works around the fact that class members aren't gin-configurable."""
+  preprocessing_layers = tf.nest.map_structure(
+      preprocessing_layer_creator, agent_config.time_step_spec.observation)
+  return agent_config.create_agent(preprocessing_layers, policy_network)
+
+
+@gin.configurable(module='agents')
+class BCAgentConfig(AgentConfig):
+  """Behavioral Cloning agent configuration."""
+
+  def create_agent(self, preprocessing_layers: tf.keras.layers.Layer,
+                   policy_network: types.Network) -> tf_agent.TFAgent:
+    """Creates a behavioral_cloning_agent."""
+
+    network = policy_network(
+        self.time_step_spec.observation,
+        self.action_spec,
+        preprocessing_layers=preprocessing_layers,
+        name='QNetwork')
+
+    return behavioral_cloning_agent.BehavioralCloningAgent(
+        self.time_step_spec,
+        self.action_spec,
+        cloning_network=network,
+        num_outer_dims=2)
+
+
+@gin.configurable(module='agents')
+class DQNAgentConfig(AgentConfig):
+  """DQN agent configuration."""
+
+  def create_agent(self, preprocessing_layers: tf.keras.layers.Layer,
+                   policy_network: types.Network) -> tf_agent.TFAgent:
+    """Creates a dqn_agent."""
+    network = policy_network(
+        self.time_step_spec.observation,
+        self.action_spec,
+        preprocessing_layers=preprocessing_layers,
+        name='QNetwork')
+
+    return dqn_agent.DqnAgent(
+        self.time_step_spec, self.action_spec, q_network=network)
+
+
+@gin.configurable(module='agents')
+class PPOAgentConfig(AgentConfig):
+  """PPO/Reinforce agent configuration."""
+
+  def create_agent(self, preprocessing_layers: tf.keras.layers.Layer,
+                   policy_network: types.Network) -> tf_agent.TFAgent:
+    """Creates a ppo_agent."""
+
+    actor_network = policy_network(
+        self.time_step_spec.observation,
+        self.action_spec,
+        preprocessing_layers=preprocessing_layers,
+        name='ActorDistributionNetwork')
+
+    critic_network = constant_value_network.ConstantValueNetwork(
+        self.time_step_spec.observation, name='ConstantValueNetwork')
+
+    return ppo_agent.PPOAgent(
+        self.time_step_spec,
+        self.action_spec,
+        actor_net=actor_network,
+        value_net=critic_network)
+
+  def get_policy_info_parsing_dict(
+      self) -> Dict[str, tf.io.FixedLenSequenceFeature]:
+    if tensor_spec.is_discrete(self._action_spec):
+      return {
+          'CategoricalProjectionNetwork_logits':
+              tf.io.FixedLenSequenceFeature(
+                  shape=(self._action_spec.maximum - self._action_spec.minimum +
+                         1),
+                  dtype=tf.float32)
+      }
+    else:
+      return {
+          'NormalProjectionNetwork_scale':
+              tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32),
+          'NormalProjectionNetwork_loc':
+              tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32)
+      }
+
+  def process_parsed_sequence_and_get_policy_info(
+      self, parsed_sequence: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
+    if tensor_spec.is_discrete(self._action_spec):
+      policy_info = {
+          'dist_params': {
+              'logits': parsed_sequence['CategoricalProjectionNetwork_logits']
+          }
+      }
+      del parsed_sequence['CategoricalProjectionNetwork_logits']
+    else:
+      policy_info = {
+          'dist_params': {
+              'scale': parsed_sequence['NormalProjectionNetwork_scale'],
+              'loc': parsed_sequence['NormalProjectionNetwork_loc']
+          }
+      }
+      del parsed_sequence['NormalProjectionNetwork_scale']
+      del parsed_sequence['NormalProjectionNetwork_loc']
+    return policy_info
+
+
+@gin.configurable(module='agents')
+class DistributedPPOAgentConfig(PPOAgentConfig):
+  """Distributed PPO/Reinforce agent configuration."""
+
+  def _create_agent_implt(self, preprocessing_layers: tf.keras.layers.Layer,
+                          policy_network: types.Network) -> tf_agent.TFAgent:
+    """Creates a ppo_distributed agent."""
+    actor_network = policy_network(
+        self.time_step_spec.observation,
+        self.action_spec,
+        preprocessing_layers=preprocessing_layers,
+        preprocessing_combiner=tf.keras.layers.Concatenate(),
+        name='ActorDistributionNetwork')
+
+    critic_network = constant_value_network.ConstantValueNetwork(
+        self.time_step_spec.observation, name='ConstantValueNetwork')
+
+    return distributed_ppo_agent.MLGOPPOAgent(
+        self.time_step_spec,
+        self.action_spec,
+        optimizer=tf.keras.optimizers.Adam(learning_rate=4e-4, epsilon=1e-5),
+        actor_net=actor_network,
+        value_net=critic_network,
+        value_pred_loss_coef=0.0,
+        entropy_regularization=0.01,
+        importance_ratio_clipping=0.2,
+        discount_factor=1.0,
+        gradient_clipping=1.0,
+        debug_summaries=False,
+        value_clipping=None,
+        aggregate_losses_across_replicas=True,
+        loss_scaling_factor=1.0)
```
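The two data-processing hooks introduced on `AgentConfig` (`get_policy_info_parsing_dict` and `process_parsed_sequence_and_get_policy_info`) are what replace the per-agent if-checks on the data side. Below is a hedged sketch of how a `SequenceExample` parser might consume them; the project's actual reader lives in other files of this commit (not shown in this excerpt), and `sequence_features` is a placeholder for the observation/action/reward feature spec.

```python
from typing import Any, Dict, Tuple

import tensorflow as tf

from compiler_opt.rl import agent_creators


def parse_sequence_example(
    serialized_proto: tf.Tensor, agent_config: agent_creators.AgentConfig,
    sequence_features: Dict[str, tf.io.FixedLenSequenceFeature]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  """Parses one SequenceExample, deferring policy-info handling to the config."""
  # Ask the config which extra policy-info features to parse: the base class
  # contributes nothing, while PPOAgentConfig adds logits or loc/scale features.
  sequence_features = dict(sequence_features)
  sequence_features.update(agent_config.get_policy_info_parsing_dict())
  _, parsed_sequence = tf.io.parse_single_sequence_example(
      serialized_proto, sequence_features=sequence_features)
  # Let the config extract (and strip) its policy info from the parsed tensors;
  # what remains are the observation/action/reward features.
  policy_info = agent_config.process_parsed_sequence_and_get_policy_info(
      parsed_sequence)
  return parsed_sequence, policy_info
```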

compiler_opt/rl/agent_creators_test.py

Lines changed: 6 additions & 10 deletions
```diff
@@ -25,7 +25,6 @@
 from tf_agents.trajectories import time_step
 
 from compiler_opt.rl import agent_creators
-from compiler_opt.rl import constant
 
 
 def _observation_processing_layer(obs_spec):
@@ -56,9 +55,8 @@ def test_create_behavioral_cloning_agent(self):
     gin.bind_parameter('BehavioralCloningAgent.optimizer',
                        tf.compat.v1.train.AdamOptimizer())
     tf_agent = agent_creators.create_agent(
-        agent_name=constant.AgentName.BEHAVIORAL_CLONE,
-        time_step_spec=self._time_step_spec,
-        action_spec=self._action_spec,
+        agent_creators.BCAgentConfig(
+            time_step_spec=self._time_step_spec, action_spec=self._action_spec),
         preprocessing_layer_creator=_observation_processing_layer)
     self.assertIsInstance(tf_agent,
                           behavioral_cloning_agent.BehavioralCloningAgent)
@@ -67,9 +65,8 @@ def test_create_dqn_agent(self):
     gin.bind_parameter('create_agent.policy_network', q_network.QNetwork)
     gin.bind_parameter('DqnAgent.optimizer', tf.compat.v1.train.AdamOptimizer())
     tf_agent = agent_creators.create_agent(
-        agent_name=constant.AgentName.DQN,
-        time_step_spec=self._time_step_spec,
-        action_spec=self._action_spec,
+        agent_creators.DQNAgentConfig(
+            time_step_spec=self._time_step_spec, action_spec=self._action_spec),
         preprocessing_layer_creator=_observation_processing_layer)
     self.assertIsInstance(tf_agent, dqn_agent.DqnAgent)
 
@@ -78,9 +75,8 @@ def test_create_ppo_agent(self):
                        actor_distribution_network.ActorDistributionNetwork)
     gin.bind_parameter('PPOAgent.optimizer', tf.compat.v1.train.AdamOptimizer())
     tf_agent = agent_creators.create_agent(
-        agent_name=constant.AgentName.PPO,
-        time_step_spec=self._time_step_spec,
-        action_spec=self._action_spec,
+        agent_creators.PPOAgentConfig(
+            time_step_spec=self._time_step_spec, action_spec=self._action_spec),
         preprocessing_layer_creator=_observation_processing_layer)
     self.assertIsInstance(tf_agent, ppo_agent.PPOAgent)
 
```
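One consequence of the new scheme, illustrated with a hypothetical config that is not part of this commit: adding another agent type now means adding one `AgentConfig` subclass, with no edits to `create_agent` or to the data-processing code. The sketch below assumes the tf_agents `DdqnAgent` and passes the optimizer explicitly rather than binding it through gin as the tests above do.

```python
import gin
import tensorflow as tf
from tf_agents.agents import tf_agent
from tf_agents.agents.dqn import dqn_agent
from tf_agents.typing import types

from compiler_opt.rl import agent_creators


@gin.configurable(module='agents')
class DoubleDQNAgentConfig(agent_creators.AgentConfig):
  """Hypothetical config showing how a new agent type would be plugged in."""

  def create_agent(self, preprocessing_layers: tf.keras.layers.Layer,
                   policy_network: types.Network) -> tf_agent.TFAgent:
    # Build the Q-network from the shared preprocessing layers, mirroring
    # DQNAgentConfig above.
    network = policy_network(
        self.time_step_spec.observation,
        self.action_spec,
        preprocessing_layers=preprocessing_layers,
        name='QNetwork')
    # The no-op data-processing hooks are inherited from AgentConfig.
    return dqn_agent.DdqnAgent(
        self.time_step_spec,
        self.action_spec,
        q_network=network,
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3))
```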