Commit d5f076d

Type APIs working with policies (#66)
This clarifies that we use TFAgents rather than raw saved models, and helps with readability and discoverability (the latter especially in an IDE).
1 parent fca414f commit d5f076d
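
To illustrate the distinction the commit message draws, here is a minimal sketch (not code from this repository): the training scripts handle live TF-Agents policy objects in memory, while a "raw saved model" only exists once a policy has been exported to disk.

from tf_agents.agents import TFAgent
from tf_agents.policies import TFPolicy
from tf_agents.trajectories import time_step as ts


def act_once(agent: TFAgent, time_step: ts.TimeStep):
  # The training scripts pass around live TF-Agents objects: agent.policy
  # is a TFPolicy that can be queried in memory. A "raw saved model" only
  # appears later, when PolicySaver exports such a policy to disk.
  policy: TFPolicy = agent.policy
  return policy.action(time_step)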

File tree

5 files changed: +34 / -20 lines changed

compiler_opt/rl/agent_creators.py

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ def create_agent(agent_name: constant.AgentName,
                  action_spec: types.NestedTensorSpec,
                  preprocessing_layer_creator: Callable[[types.TensorSpec],
                                                        tf.keras.layers.Layer],
-                 policy_network: types.Network):
+                 policy_network: types.Network) -> TFAgent:
   """Creates a tfa.agents.TFAgent object.
 
   Args:
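
A hedged usage sketch of the annotated factory. The wrapper function and its arguments are placeholders, and the import path for constant is an assumption. Note that the call sites updated in this commit (train_bc.py, train_locally.py) omit the trailing policy_network argument, presumably supplying it through gin configuration; the sketch passes it explicitly to mirror the signature above.

from tf_agents.agents import TFAgent

from compiler_opt.rl import agent_creators, constant


def make_bc_agent(time_step_spec, action_spec, preprocessing_layer_creator,
                  policy_network) -> TFAgent:
  # With create_agent's `-> TFAgent` return annotation, an IDE or type
  # checker now knows the result exposes TFAgent members such as .policy,
  # .collect_policy, and .train.
  return agent_creators.create_agent(constant.AgentName.BEHAVIORAL_CLONE,
                                     time_step_spec, action_spec,
                                     preprocessing_layer_creator,
                                     policy_network)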

compiler_opt/rl/policy_saver.py

Lines changed: 11 additions & 7 deletions
@@ -18,8 +18,11 @@
 import os
 
 import tensorflow as tf
+from tf_agents.policies import TFPolicy
 from tf_agents.policies import policy_saver
 
+from typing import Dict, Tuple
+
 OUTPUT_SIGNATURE = 'output_spec.json'
 
 _TYPE_CONVERSION_DICT = {
@@ -74,17 +77,18 @@ class PolicySaver(object):
   ```
   """
 
-  def __init__(self, policy_dict):
+  def __init__(self, policy_dict: Dict[str, TFPolicy]):
     """Initialize the PolicySaver object.
 
     Args:
       policy_dict: A dict mapping from policy name to policy.
     """
-    self._policy_saver_dict = {
-        policy_name: (policy_saver.PolicySaver(
-            policy, batch_size=1, use_nest_path_signatures=False), policy)
-        for policy_name, policy in policy_dict.items()
-    }
+    self._policy_saver_dict: Dict[str, Tuple[
+        policy_saver.PolicySaver, TFPolicy]] = {
+            policy_name: (policy_saver.PolicySaver(
+                policy, batch_size=1, use_nest_path_signatures=False), policy)
+            for policy_name, policy in policy_dict.items()
+        }
 
   def _save_policy(self, saver, path):
     """Writes policy, model weights and model_binding.txt to path/."""
@@ -149,7 +153,7 @@ def _write_output_signature(self, saver, path):
     with tf.io.gfile.GFile(os.path.join(path, OUTPUT_SIGNATURE), 'w') as f:
       f.write(json.dumps(output_list))
 
-  def save(self, root_dir):
+  def save(self, root_dir: str):
     """Writes policy and model_binding.txt to root_dir/policy_name/."""
     for policy_name, (saver, _) in self._policy_saver_dict.items():
       self._save_policy(saver, os.path.join(root_dir, policy_name))
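
A usage sketch for the annotated constructor and save(). The wrapper function is illustrative only; the policy names mirror the dict built in train_bc.py below.

from typing import Dict

from tf_agents.agents import TFAgent
from tf_agents.policies import TFPolicy

from compiler_opt.rl import policy_saver


def export_policies(agent: TFAgent, root_dir: str) -> None:
  # Keys are policy names; values are live TFPolicy objects.
  policies: Dict[str, TFPolicy] = {
      'saved_policy': agent.policy,
      'saved_collect_policy': agent.collect_policy,
  }
  # Each policy (plus model_binding.txt) is written under root_dir/<name>/.
  policy_saver.PolicySaver(policies).save(root_dir)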

compiler_opt/rl/train_bc.py

Lines changed: 9 additions & 4 deletions
@@ -30,6 +30,11 @@
 from compiler_opt.rl import registry
 from compiler_opt.rl import trainer
 
+from tf_agents.agents import TFAgent
+from tf_agents.policies import TFPolicy
+
+from typing import Dict
+
 _ROOT_DIR = flags.DEFINE_string(
     'root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
     'Root directory for writing logs/summaries/checkpoints.')
@@ -58,11 +63,11 @@ def train_eval(agent_name=constant.AgentName.BEHAVIORAL_CLONE,
   preprocessing_layer_creator = problem_config.get_preprocessing_layer_creator()
 
   # Initialize trainer and policy saver.
-  tf_agent = agent_creators.create_agent(agent_name, time_step_spec,
-                                         action_spec,
-                                         preprocessing_layer_creator)
+  tf_agent: TFAgent = agent_creators.create_agent(agent_name, time_step_spec,
+                                                  action_spec,
+                                                  preprocessing_layer_creator)
   llvm_trainer = trainer.Trainer(root_dir=root_dir, agent=tf_agent)
-  policy_dict = {
+  policy_dict: Dict[str, TFPolicy] = {
       'saved_policy': tf_agent.policy,
       'saved_collect_policy': tf_agent.collect_policy,
   }

compiler_opt/rl/train_locally.py

Lines changed: 4 additions & 3 deletions
@@ -25,6 +25,7 @@
 from absl import logging
 import gin
 import tensorflow as tf
+from tf_agents.agents import TFAgent
 from tf_agents.system import system_multiprocessing as multiprocessing
 from typing import List
 
@@ -77,9 +78,9 @@ def train_eval(agent_name=constant.AgentName.PPO,
   preprocessing_layer_creator = problem_config.get_preprocessing_layer_creator()
 
   # Initialize trainer and policy saver.
-  tf_agent = agent_creators.create_agent(agent_name, time_step_spec,
-                                         action_spec,
-                                         preprocessing_layer_creator)
+  tf_agent: TFAgent = agent_creators.create_agent(agent_name, time_step_spec,
+                                                  action_spec,
+                                                  preprocessing_layer_creator)
   # create the random network distillation object
   random_network_distillation = None
   if use_random_network_distillation:

compiler_opt/rl/trainer.py

Lines changed: 9 additions & 5 deletions
@@ -20,9 +20,12 @@
 
 import gin
 import tensorflow as tf
+from compiler_opt.rl import random_net_distillation
+from tf_agents.agents import TFAgent
 from tf_agents.policies import policy_loader
 
 from tf_agents.utils import common as common_utils
+from typing import Optional
 
 _INLINING_DEFAULT_KEY = 'inlining_default'
 
@@ -43,10 +46,11 @@ class Trainer(object):
 
   def __init__(
       self,
-      root_dir,
-      agent,
-      random_network_distillation=None,
-      warmstart_policy_dir=None,
+      root_dir: str,
+      agent: TFAgent,
+      random_network_distillation: Optional[
+          random_net_distillation.RandomNetworkDistillation] = None,
+      warmstart_policy_dir: Optional[str] = None,
       # Params for summaries and logging
       checkpoint_interval=10000,
       log_interval=100,
@@ -180,7 +184,7 @@ def _save_checkpoint(self):
   def global_step_numpy(self):
     return self._global_step.numpy()
 
-  def train(self, dataset_iter, monitor_dict, num_iterations):
+  def train(self, dataset_iter, monitor_dict, num_iterations: int):
     """Trains policy with data from dataset_iter for num_iterations steps."""
     self._reset_metrics()
     # context management is implemented in decorator
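
Finally, a sketch of how the annotated Trainer is driven. dataset_iter and monitor_dict stand in for objects the real pipeline constructs, and the iteration count is arbitrary.

from tf_agents.agents import TFAgent

from compiler_opt.rl import trainer


def run_training(agent: TFAgent, root_dir: str, dataset_iter,
                 monitor_dict) -> None:
  # random_network_distillation and warmstart_policy_dir keep their None
  # defaults, so a root_dir and a TFAgent are enough to build a Trainer.
  llvm_trainer = trainer.Trainer(root_dir=root_dir, agent=agent)
  llvm_trainer.train(dataset_iter, monitor_dict, num_iterations=100)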
