Commit 0889219

Rename agent_creators to agent_config (#239)
Post PR #237, for naming consistency.
1 parent f6cc33d commit 0889219

12 files changed: +78 -80 lines changed
compiler_opt/rl/agent_creators.py renamed to compiler_opt/rl/agent_config.py

File renamed without changes.
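Call sites swap the module name, and keyword arguments named agent_config become agent_cfg. A minimal before/after sketch (create_agent and PPOAgentConfig appear in the diffs below; the spec and preprocessing-layer values are assumed to be defined by the caller):

    # Before this commit:
    # from compiler_opt.rl import agent_creators
    # tf_agent = agent_creators.create_agent(
    #     agent_creators.PPOAgentConfig(
    #         time_step_spec=time_step_spec, action_spec=action_spec),
    #     preprocessing_layer_creator=preprocessing_layer_creator)

    # After this commit:
    from compiler_opt.rl import agent_config

    tf_agent = agent_config.create_agent(
        agent_config.PPOAgentConfig(
            time_step_spec=time_step_spec, action_spec=action_spec),
        preprocessing_layer_creator=preprocessing_layer_creator)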

compiler_opt/rl/agent_creators_test.py renamed to compiler_opt/rl/agent_config_test.py

Lines changed: 8 additions & 8 deletions
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for compiler_opt.rl.agent_creators."""
+"""Tests for compiler_opt.rl.agent_config."""
 
 import gin
 import tensorflow as tf
@@ -24,7 +24,7 @@
 from tf_agents.specs import tensor_spec
 from tf_agents.trajectories import time_step
 
-from compiler_opt.rl import agent_creators
+from compiler_opt.rl import agent_config
 
 
 def _observation_processing_layer(obs_spec):
@@ -54,8 +54,8 @@ def test_create_behavioral_cloning_agent(self):
     gin.bind_parameter('create_agent.policy_network', q_network.QNetwork)
     gin.bind_parameter('BehavioralCloningAgent.optimizer',
                        tf.compat.v1.train.AdamOptimizer())
-    tf_agent = agent_creators.create_agent(
-        agent_creators.BCAgentConfig(
+    tf_agent = agent_config.create_agent(
+        agent_config.BCAgentConfig(
             time_step_spec=self._time_step_spec, action_spec=self._action_spec),
         preprocessing_layer_creator=_observation_processing_layer)
     self.assertIsInstance(tf_agent,
@@ -64,8 +64,8 @@ def test_create_behavioral_cloning_agent(self):
   def test_create_dqn_agent(self):
     gin.bind_parameter('create_agent.policy_network', q_network.QNetwork)
     gin.bind_parameter('DqnAgent.optimizer', tf.compat.v1.train.AdamOptimizer())
-    tf_agent = agent_creators.create_agent(
-        agent_creators.DQNAgentConfig(
+    tf_agent = agent_config.create_agent(
+        agent_config.DQNAgentConfig(
             time_step_spec=self._time_step_spec, action_spec=self._action_spec),
         preprocessing_layer_creator=_observation_processing_layer)
     self.assertIsInstance(tf_agent, dqn_agent.DqnAgent)
@@ -74,8 +74,8 @@ def test_create_ppo_agent(self):
     gin.bind_parameter('create_agent.policy_network',
                        actor_distribution_network.ActorDistributionNetwork)
     gin.bind_parameter('PPOAgent.optimizer', tf.compat.v1.train.AdamOptimizer())
-    tf_agent = agent_creators.create_agent(
-        agent_creators.PPOAgentConfig(
+    tf_agent = agent_config.create_agent(
+        agent_config.PPOAgentConfig(
             time_step_spec=self._time_step_spec, action_spec=self._action_spec),
         preprocessing_layer_creator=_observation_processing_layer)
     self.assertIsInstance(tf_agent, ppo_agent.PPOAgent)
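As the tests above illustrate, create_agent reads its policy network and optimizer from gin, so those bindings must be in place before the call. A standalone sketch of the same pattern, assuming time_step_spec, action_spec, and preprocessing_layer_creator are built elsewhere (e.g. by a problem configuration):

    import gin
    import tensorflow as tf
    from tf_agents.networks import q_network
    from compiler_opt.rl import agent_config

    # Bind the network class and optimizer that create_agent pulls from gin.
    gin.bind_parameter('create_agent.policy_network', q_network.QNetwork)
    gin.bind_parameter('DqnAgent.optimizer', tf.compat.v1.train.AdamOptimizer())

    tf_agent = agent_config.create_agent(
        agent_config.DQNAgentConfig(
            time_step_spec=time_step_spec, action_spec=action_spec),
        preprocessing_layer_creator=preprocessing_layer_creator)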

compiler_opt/rl/data_reader.py

Lines changed: 24 additions & 24 deletions
@@ -19,11 +19,11 @@
 import tensorflow as tf
 from tf_agents.trajectories import trajectory
 
-from compiler_opt.rl import agent_creators
+from compiler_opt.rl import agent_config
 
 
 def create_parser_fn(
-    agent_config: agent_creators.AgentConfig
+    agent_cfg: agent_config.AgentConfig
 ) -> Callable[[str], trajectory.Trajectory]:
   """Create a parser function for reading from a serialized tf.SequenceExample.
 
@@ -48,16 +48,16 @@ def _parser_fn(serialized_proto):
         (tensor_spec.name,
          tf.io.FixedLenSequenceFeature(
              shape=tensor_spec.shape, dtype=tensor_spec.dtype))
-        for tensor_spec in agent_config.time_step_spec.observation.values())
+        for tensor_spec in agent_cfg.time_step_spec.observation.values())
     sequence_features[
-        agent_config.action_spec.name] = tf.io.FixedLenSequenceFeature(
-            shape=agent_config.action_spec.shape,
-            dtype=agent_config.action_spec.dtype)
-    sequence_features[agent_config.time_step_spec.reward
-                      .name] = tf.io.FixedLenSequenceFeature(
-                          shape=agent_config.time_step_spec.reward.shape,
-                          dtype=agent_config.time_step_spec.reward.dtype)
-    sequence_features.update(agent_config.get_policy_info_parsing_dict())
+        agent_cfg.action_spec.name] = tf.io.FixedLenSequenceFeature(
+            shape=agent_cfg.action_spec.shape,
+            dtype=agent_cfg.action_spec.dtype)
+    sequence_features[
+        agent_cfg.time_step_spec.reward.name] = tf.io.FixedLenSequenceFeature(
+            shape=agent_cfg.time_step_spec.reward.shape,
+            dtype=agent_cfg.time_step_spec.reward.dtype)
+    sequence_features.update(agent_cfg.get_policy_info_parsing_dict())
 
     # pylint: enable=g-complex-comprehension
     with tf.name_scope('parse'):
@@ -66,15 +66,15 @@ def _parser_fn(serialized_proto):
           context_features=context_features,
           sequence_features=sequence_features)
       # TODO(yundi): make the transformed reward configurable.
-      action = parsed_sequence[agent_config.action_spec.name]
-      reward = tf.cast(parsed_sequence[agent_config.time_step_spec.reward.name],
+      action = parsed_sequence[agent_cfg.action_spec.name]
+      reward = tf.cast(parsed_sequence[agent_cfg.time_step_spec.reward.name],
                        tf.float32)
 
-      policy_info = agent_config.process_parsed_sequence_and_get_policy_info(
+      policy_info = agent_cfg.process_parsed_sequence_and_get_policy_info(
           parsed_sequence)
 
-      del parsed_sequence[agent_config.time_step_spec.reward.name]
-      del parsed_sequence[agent_config.action_spec.name]
+      del parsed_sequence[agent_cfg.time_step_spec.reward.name]
+      del parsed_sequence[agent_cfg.action_spec.name]
       full_trajectory = trajectory.from_episode(
           observation=parsed_sequence,
          action=action,
@@ -86,7 +86,7 @@ def _parser_fn(serialized_proto):
 
 
 def create_flat_sequence_example_dataset_fn(
-    agent_config: agent_creators.AgentConfig
+    agent_cfg: agent_config.AgentConfig
 ) -> Callable[[List[str]], tf.data.Dataset]:
   """Get a function that creates a dataset from serialized sequence examples.
 
@@ -103,7 +103,7 @@ def create_flat_sequence_example_dataset_fn(
     a `tf.data.Dataset`. Treating this dataset as an iterator yields batched
     `trajectory.Trajectory` instances with shape `[...]`.
   """
-  parser_fn = create_parser_fn(agent_config)
+  parser_fn = create_parser_fn(agent_cfg)
 
   def _sequence_example_dataset_fn(sequence_examples):
     # Data collector returns empty strings for corner cases, filter them out
@@ -123,7 +123,7 @@ def _sequence_example_dataset_fn(sequence_examples):
 
 
 def create_sequence_example_dataset_fn(
-    agent_config: agent_creators.AgentConfig, batch_size: int,
+    agent_cfg: agent_config.AgentConfig, batch_size: int,
     train_sequence_length: int) -> Callable[[List[str]], tf.data.Dataset]:
   """Get a function that creates a dataset from serialized sequence examples.
 
@@ -142,7 +142,7 @@ def create_sequence_example_dataset_fn(
   trajectory_shuffle_buffer_size = 1024
 
   flat_sequence_example_dataset_fn = create_flat_sequence_example_dataset_fn(
-      agent_config)
+      agent_cfg)
 
   def _sequence_example_dataset_fn(sequence_examples):
     # Data collector returns empty strings for corner cases, filter them out
@@ -160,7 +160,7 @@ def _sequence_example_dataset_fn(sequence_examples):
 # TODO(yundi): PyType check of input_dataset as Type[tf.data.Dataset] is not
 # working.
 def create_file_dataset_fn(
-    agent_config: agent_creators.AgentConfig,
+    agent_cfg: agent_config.AgentConfig,
     batch_size: int,
     train_sequence_length: int,
     input_dataset) -> Callable[[List[str]], tf.data.Dataset]:
@@ -185,7 +185,7 @@ def create_file_dataset_fn(
   shuffle_buffer_size = 1024
   trajectory_shuffle_buffer_size = 1024
 
-  parser_fn = create_parser_fn(agent_config)
+  parser_fn = create_parser_fn(agent_cfg)
 
   def _file_dataset_fn(data_path):
     dataset = (
@@ -213,7 +213,7 @@ def _file_dataset_fn(data_path):
 
 
 def create_tfrecord_dataset_fn(
-    agent_config: agent_creators.AgentConfig, batch_size: int,
+    agent_cfg: agent_config.AgentConfig, batch_size: int,
     train_sequence_length: int) -> Callable[[List[str]], tf.data.Dataset]:
   """Get a function that creates an dataset from tfrecord.
 
@@ -230,7 +230,7 @@ def create_tfrecord_dataset_fn(
     shape `[B, T, ...]`.
   """
   return create_file_dataset_fn(
-      agent_config,
+      agent_cfg,
       batch_size,
       train_sequence_length,
       input_dataset=tf.data.TFRecordDataset)
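Since the factories in data_reader now take agent_cfg, external callers that passed the old agent_config keyword must update as well. A minimal usage sketch (the specs are assumed to come from a problem configuration, and the tfrecord path, batch size, and sequence length are placeholders):

    from compiler_opt.rl import agent_config
    from compiler_opt.rl import data_reader

    agent_cfg = agent_config.PPOAgentConfig(
        time_step_spec=time_step_spec, action_spec=action_spec)
    dataset_fn = data_reader.create_tfrecord_dataset_fn(
        agent_cfg=agent_cfg, batch_size=2, train_sequence_length=3)
    # Iterating the returned dataset yields batched trajectory.Trajectory
    # instances with shape [B, T, ...].
    dataset = dataset_fn(['/tmp/example.tfrecord'])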

compiler_opt/rl/data_reader_test.py

Lines changed: 11 additions & 11 deletions
@@ -24,7 +24,7 @@
 from tf_agents.trajectories import time_step
 from tf_agents.trajectories import trajectory
 
-from compiler_opt.rl import agent_creators
+from compiler_opt.rl import agent_config
 from compiler_opt.rl import data_reader
 
 
@@ -41,8 +41,8 @@ def _define_sequence_example(agent_config_type, is_action_discrete):
   ).float_list.value.append(1.23)
   example.feature_lists.feature_list['reward'].feature.add(
   ).float_list.value.append(2.3)
-  if agent_config_type in (agent_creators.PPOAgentConfig,
-                           agent_creators.DistributedPPOAgentConfig):
+  if agent_config_type in (agent_config.PPOAgentConfig,
+                           agent_config.DistributedPPOAgentConfig):
     if is_action_discrete:
       example.feature_lists.feature_list[
           'CategoricalProjectionNetwork_logits'].feature.add(
@@ -97,20 +97,20 @@ def _create_tfrecord_datasource(self, example):
 
   _test_config = (('SequenceExampleDatasetFn',
                    data_reader.create_sequence_example_dataset_fn,
-                   agent_creators.PPOAgentConfig,
+                   agent_config.PPOAgentConfig,
                    _create_sequence_example_datasource),
                   ('TFRecordDatasetFn', data_reader.create_tfrecord_dataset_fn,
-                   agent_creators.PPOAgentConfig, _create_tfrecord_datasource))
+                   agent_config.PPOAgentConfig, _create_tfrecord_datasource))
 
   @parameterized.named_parameters(*_test_config)
   def test_create_dataset_fn(self, test_fn, _, data_source_fn):
-    agent_type_override = agent_creators.DQNAgentConfig
+    agent_type_override = agent_config.DQNAgentConfig
     example = _define_sequence_example(
         agent_type_override, is_action_discrete=True)
 
     data_source = data_source_fn(self, example)
     dataset_fn = test_fn(
-        agent_config=agent_type_override(
+        agent_cfg=agent_type_override(
             time_step_spec=self._time_step_spec,
             action_spec=self._discrete_action_spec),
         batch_size=2,
@@ -131,11 +131,11 @@ def test_create_dataset_fn(self, test_fn, _, data_source_fn):
 
   _distrib_test_config = (('SequenceExampleDatasetFnDistributed',
                            data_reader.create_sequence_example_dataset_fn,
-                           agent_creators.DistributedPPOAgentConfig,
+                           agent_config.DistributedPPOAgentConfig,
                            _create_sequence_example_datasource),
                           ('TFRecordDatasetFnDistributed',
                            data_reader.create_tfrecord_dataset_fn,
-                           agent_creators.DistributedPPOAgentConfig,
+                           agent_config.DistributedPPOAgentConfig,
                            _create_tfrecord_datasource))
 
   @parameterized.named_parameters(*(_test_config + _distrib_test_config))
@@ -147,7 +147,7 @@ def test_ppo_policy_info_discrete(self, test_fn, agent_config_type,
     data_source = data_source_fn(self, example)
 
     dataset_fn = test_fn(
-        agent_config=agent_config_type(
+        agent_cfg=agent_config_type(
             time_step_spec=self._time_step_spec,
             action_spec=self._discrete_action_spec),
         batch_size=2,
@@ -169,7 +169,7 @@ def test_ppo_policy_info_continuous(self, test_fn, agent_config_type,
     data_source = data_source_fn(self, example)
 
     dataset_fn = test_fn(
-        agent_config=agent_config_type(
+        agent_cfg=agent_config_type(
             time_step_spec=self._time_step_spec,
             action_spec=self._continuous_action_spec),
         batch_size=2,

compiler_opt/rl/distributed/ppo_collect_lib.py

Lines changed: 8 additions & 8 deletions
@@ -41,7 +41,7 @@
 from compiler_opt.rl import data_reader
 from compiler_opt.rl import policy_saver
 from compiler_opt.rl import registry
-from compiler_opt.rl import agent_creators
+from compiler_opt.rl import agent_config
 from compiler_opt.rl import compilation_runner
 
 
@@ -65,7 +65,7 @@ class ReverbCompilationObserver(compilation_runner.CompilationResultObserver):
   """Observer which sends compilation results to reverb"""
 
   def __init__(self,
-               agent_config,
+               agent_cfg,
                replay_buffer_server_address: str,
                sequence_length: int,
                initial_priority: float = 0.0):
@@ -79,7 +79,7 @@ def __init__(self,
         priority=initial_priority)
 
     self._parser = data_reader.create_flat_sequence_example_dataset_fn(
-        agent_config=agent_config)
+        agent_cfg=agent_cfg)
 
   def _is_actionable_result(
       self, result: compilation_runner.CompilationResult) -> bool:
@@ -121,10 +121,10 @@ def collect(corpus_path: str, replay_buffer_server_address: str,
   logging.info('Initializing the distributed PPO agent')
   problem_config = registry.get_configuration()
   time_step_spec, action_spec = problem_config.get_signature_spec()
-  agent_config = agent_creators.DistributedPPOAgentConfig(
+  agent_cfg = agent_config.DistributedPPOAgentConfig(
       time_step_spec=time_step_spec, action_spec=action_spec)
-  agent = agent_creators.create_agent(
-      agent_config.agent,
+  agent = agent_config.create_agent(
+      agent_cfg.agent,
       preprocessing_layer_creator=problem_config
       .get_preprocessing_layer_creator())
 
@@ -145,15 +145,15 @@ def collect(corpus_path: str, replay_buffer_server_address: str,
   create_observer_fns = [
       functools.partial(
           ReverbCompilationObserver,
-          agent_config=agent_config,
+          agent_cfg=agent_cfg,
          replay_buffer_server_address=replay_buffer_server_address,
          sequence_length=sequence_length)
   ]
 
   # Setup the corpus
   logging.info('Constructing tf.data pipeline and module corpus')
   dataset_fn = data_reader.create_flat_sequence_example_dataset_fn(
-      agent_config=agent_config)
+      agent_cfg=agent_cfg)
 
   def sequence_example_iterator_fn(seq_ex: List[str]):
     return iter(dataset_fn(seq_ex).prefetch(tf.data.AUTOTUNE))
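The renamed keyword flows through the observer wiring in collect() as well; a condensed sketch of the construction above (the reverb server address and sequence length are caller-supplied placeholders, and the keyword matches the renamed __init__ parameter):

    import functools

    agent_cfg = agent_config.DistributedPPOAgentConfig(
        time_step_spec=time_step_spec, action_spec=action_spec)
    create_observer_fns = [
        functools.partial(
            ReverbCompilationObserver,
            agent_cfg=agent_cfg,
            replay_buffer_server_address=replay_buffer_server_address,
            sequence_length=sequence_length)
    ]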

compiler_opt/rl/distributed/ppo_eval_lib.py

Lines changed: 5 additions & 5 deletions
@@ -33,7 +33,7 @@
 from compiler_opt.rl import local_data_collector
 from compiler_opt.rl import gin_external_configurables  # pylint: disable=unused-import
 from compiler_opt.rl import corpus
-from compiler_opt.rl import agent_creators
+from compiler_opt.rl import agent_config
 from compiler_opt.rl import registry
 from compiler_opt.rl import policy_saver
 from compiler_opt.rl import data_collector
@@ -57,10 +57,10 @@ def evaluate(root_dir: str, corpus_path: str,
   logging.info('Initializing the distributed PPO agent')
   problem_config = registry.get_configuration()
   time_step_spec, action_spec = problem_config.get_signature_spec()
-  agent_config = agent_creators.DistributedPPOAgentConfig(
+  agent_cfg = agent_config.DistributedPPOAgentConfig(
       time_step_spec=time_step_spec, action_spec=action_spec)
-  agent = agent_creators.create_agent(
-      agent_config.agent,
+  agent = agent_config.create_agent(
+      agent_cfg.agent,
       preprocessing_layer_creator=problem_config
       .get_preprocessing_layer_creator())
 
@@ -85,7 +85,7 @@ def evaluate(root_dir: str, corpus_path: str,
   # Setup the corpus
   logging.info('Constructing tf.data pipeline and module corpus')
   dataset_fn = data_reader.create_flat_sequence_example_dataset_fn(
-      agent_config=agent_config)
+      agent_cfg=agent_cfg)
 
   def sequence_example_iterator_fn(seq_ex: List[str]):
     return iter(dataset_fn(seq_ex).prefetch(tf.data.AUTOTUNE))

compiler_opt/rl/distributed/ppo_reverb_server.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@
 
 from compiler_opt.rl.distributed import ppo_reverb_server_lib
 from compiler_opt.rl import registry  # pylint: disable=unused-import
-from compiler_opt.rl import agent_creators  # pylint: disable=unused-import
+from compiler_opt.rl import agent_config  # pylint: disable=unused-import
 
 flags.DEFINE_string('root_dir', None,
                     'Root directory for writing logs/summaries/checkpoints.')

compiler_opt/rl/distributed/ppo_train_lib.py

Lines changed: 3 additions & 3 deletions
@@ -30,7 +30,7 @@
 from tf_agents.utils import common
 
 from compiler_opt.rl import gin_external_configurables  # pylint: disable=unused-import
-from compiler_opt.rl import agent_creators
+from compiler_opt.rl import agent_config
 from compiler_opt.rl import registry
 from compiler_opt.rl.distributed import learner as learner_lib
 
@@ -58,8 +58,8 @@ def train(
   # Create the agent.
   with strategy.scope():
     train_step = tf.compat.v1.train.get_or_create_global_step()
-    agent = agent_creators.create_agent(
-        agent_creators.DistributedPPOAgentConfig(
+    agent = agent_config.create_agent(
+        agent_config.DistributedPPOAgentConfig(
             time_step_spec=time_step_spec, action_spec=action_spec),
         preprocessing_layer_creator=problem_config
         .get_preprocessing_layer_creator())
