Skip to content

Commit b82deb3

Browse files
authored
Fix spurious AgentConfig member reference (#277)
1 parent 5bf8776 commit b82deb3

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

compiler_opt/rl/distributed/ppo_collect_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def collect(corpus_path: str, replay_buffer_server_address: str,
124124
agent_cfg = agent_config.DistributedPPOAgentConfig(
125125
time_step_spec=time_step_spec, action_spec=action_spec)
126126
agent = agent_config.create_agent(
127-
agent_cfg.agent,
127+
agent_cfg,
128128
preprocessing_layer_creator=problem_config
129129
.get_preprocessing_layer_creator())
130130

compiler_opt/rl/distributed/ppo_eval_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def evaluate(root_dir: str, corpus_path: str,
6060
agent_cfg = agent_config.DistributedPPOAgentConfig(
6161
time_step_spec=time_step_spec, action_spec=action_spec)
6262
agent = agent_config.create_agent(
63-
agent_cfg.agent,
63+
agent_cfg,
6464
preprocessing_layer_creator=problem_config
6565
.get_preprocessing_layer_creator())
6666

compiler_opt/rl/train_locally.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def train_eval(worker_manager_class=LocalWorkerPoolManager,
8282
agent_cfg = agent_config_type(
8383
time_step_spec=time_step_spec, action_spec=action_spec)
8484
agent: tf_agent.TFAgent = agent_config.create_agent(
85-
agent_cfg.agent, preprocessing_layer_creator=preprocessing_layer_creator)
85+
agent_cfg, preprocessing_layer_creator=preprocessing_layer_creator)
8686
# create the random network distillation object
8787
random_network_distillation = None
8888
if use_random_network_distillation:

0 commit comments

Comments
 (0)