Skip to content

Commit 457a8c2

Browse files
authored
a script for generating test model. (#125)
1 parent c892802 commit 457a8c2

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Generate test model given a problem and an algorithm."""
16+
17+
import os
18+
19+
from absl import app
20+
from absl import flags
21+
from absl import logging
22+
23+
import gin
24+
25+
from compiler_opt.rl import agent_creators
26+
from compiler_opt.rl import constant
27+
from compiler_opt.rl import gin_external_configurables # pylint: disable=unused-import
28+
from compiler_opt.rl import policy_saver
29+
from compiler_opt.rl import registry
30+
31+
flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
32+
'Root directory for writing saved models.')
33+
flags.DEFINE_multi_string('gin_files', [],
34+
'List of paths to gin configuration files.')
35+
flags.DEFINE_multi_string(
36+
'gin_bindings', [],
37+
'Gin bindings to override the values set in the config files.')
38+
39+
FLAGS = flags.FLAGS
40+
41+
42+
@gin.configurable
43+
def generate_test_model(agent_name=constant.AgentName.PPO):
44+
"""Generate test model."""
45+
root_dir = FLAGS.root_dir
46+
47+
problem_config = registry.get_configuration()
48+
time_step_spec, action_spec = problem_config.get_signature_spec()
49+
preprocessing_layer_creator = problem_config.get_preprocessing_layer_creator()
50+
51+
# Initialize trainer and policy saver.
52+
tf_agent = agent_creators.create_agent(agent_name, time_step_spec,
53+
action_spec,
54+
preprocessing_layer_creator)
55+
56+
policy_dict = {
57+
'saved_policy': tf_agent.policy,
58+
'saved_collect_policy': tf_agent.collect_policy,
59+
}
60+
saver = policy_saver.PolicySaver(policy_dict=policy_dict)
61+
62+
# Save policy.
63+
saver.save(root_dir)
64+
65+
66+
def main(_):
67+
gin.parse_config_files_and_bindings(
68+
FLAGS.gin_files, bindings=FLAGS.gin_bindings, skip_unknown=True)
69+
logging.info(gin.config_str())
70+
71+
generate_test_model()
72+
73+
74+
if __name__ == '__main__':
75+
app.run(main)

0 commit comments

Comments
 (0)