
Commit 79d7049
add policy_utils (#279)
1 parent 45a45dd
2 files changed: +374, -0 lines

compiler_opt/es/policy_utils.py

Lines changed: 110 additions & 0 deletions
```python
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Util functions to create and edit a tf_agent policy."""

import gin
import numpy as np
import numpy.typing as npt
import tensorflow as tf
from typing import Protocol, Sequence

from compiler_opt.rl import policy_saver, registry
from tf_agents.networks import network
from tf_agents.policies import actor_policy, greedy_policy, tf_policy


class HasModelVariables(Protocol):
  model_variables: Sequence[tf.Variable]


# TODO(abenalaast): Issue #280
@gin.configurable(module='policy_utils')
def create_actor_policy(actor_network_ctor: network.DistributionNetwork,
                        greedy: bool = False) -> tf_policy.TFPolicy:
  """Creates an actor policy."""
  problem_config = registry.get_configuration()
  time_step_spec, action_spec = problem_config.get_signature_spec()
  layers = tf.nest.map_structure(
      problem_config.get_preprocessing_layer_creator(),
      time_step_spec.observation)

  actor_network = actor_network_ctor(
      input_tensor_spec=time_step_spec.observation,
      output_tensor_spec=action_spec,
      preprocessing_layers=layers)

  policy = actor_policy.ActorPolicy(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=actor_network)

  if greedy:
    policy = greedy_policy.GreedyPolicy(policy)

  return policy


def get_vectorized_parameters_from_policy(
    policy: 'tf_policy.TFPolicy | HasModelVariables'
) -> npt.NDArray[np.float32]:
  """Returns a policy's variable values as a single np array."""
  if isinstance(policy, tf_policy.TFPolicy):
    variables = policy.variables()
  elif hasattr(policy, 'model_variables'):
    variables = policy.model_variables
  else:
    raise ValueError(f'Policy must be a TFPolicy or a loaded SavedModel. '
                     f'Passed policy: {policy}')

  parameters = [var.numpy().flatten() for var in variables]
  parameters = np.concatenate(parameters, axis=0)
  return parameters


def set_vectorized_parameters_for_policy(
    policy: 'tf_policy.TFPolicy | HasModelVariables',
    parameters: npt.NDArray[np.float32]) -> None:
  """Separates values in parameters into the policy's shapes
  and sets the policy variables to those values."""
  if isinstance(policy, tf_policy.TFPolicy):
    variables = policy.variables()
  elif hasattr(policy, 'model_variables'):
    variables = policy.model_variables
  else:
    raise ValueError(f'Policy must be a TFPolicy or a loaded SavedModel. '
                     f'Passed policy: {policy}')

  param_pos = 0
  for variable in variables:
    shape = tf.shape(variable).numpy()
    num_elems = np.prod(shape)
    param = np.reshape(parameters[param_pos:param_pos + num_elems], shape)
    variable.assign(param)
    param_pos += num_elems
  if param_pos != len(parameters):
    raise ValueError(
        f'Parameter dimensions are not matched! Expected {len(parameters)} '
        f'but only found {param_pos}.')


def save_policy(policy: 'tf_policy.TFPolicy | HasModelVariables',
                parameters: npt.NDArray[np.float32], save_folder: str,
                policy_name: str) -> None:
  """Assigns a policy the name policy_name
  and saves it to the directory of save_folder
  with the values in parameters."""
  set_vectorized_parameters_for_policy(policy, parameters)
  saver = policy_saver.PolicySaver({policy_name: policy})
  saver.save(save_folder)
```
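For context, here is a minimal usage sketch (not part of the commit) of the round trip between a flat parameter vector and an object's variables. `TinyModel` is a hypothetical stand-in: per the `HasModelVariables` protocol above, any object exposing `model_variables` works when the input is not a `TFPolicy`.

```python
# Minimal sketch, assuming compiler_opt is on the Python path.
# TinyModel is hypothetical and exists only for this illustration.
import numpy as np
import tensorflow as tf

from compiler_opt.es import policy_utils


class TinyModel:
  """Satisfies the HasModelVariables protocol via duck typing."""

  def __init__(self):
    self.model_variables = [
        tf.Variable(tf.zeros((2, 3))),  # 6 parameters
        tf.Variable(tf.zeros((3,))),  # 3 parameters
    ]


model = TinyModel()
flat = np.arange(9, dtype=np.float32)  # 2*3 + 3 = 9 values in total
policy_utils.set_vectorized_parameters_for_policy(model, flat)
round_trip = policy_utils.get_vectorized_parameters_from_policy(model)
np.testing.assert_array_equal(round_trip, flat)  # variable order is preserved
```

Note that `save_policy` additionally wraps the object in `policy_saver.PolicySaver`, which is built around TF-Agents policies, so a bare variable container like the one above would presumably not be saveable; and `create_actor_policy` depends on a problem configuration being registered (and gin-configured) before `registry.get_configuration()` is called.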

compiler_opt/es/policy_utils_test.py

Lines changed: 264 additions & 0 deletions
```python
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for policy_utils."""

from absl.testing import absltest
import numpy as np
import os
import tensorflow as tf
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import actor_policy, tf_policy

from compiler_opt.es import policy_utils
from compiler_opt.rl import policy_saver, registry
from compiler_opt.rl.inlining import config as inlining_config
from compiler_opt.rl.inlining import InliningConfig
from compiler_opt.rl.regalloc import config as regalloc_config
from compiler_opt.rl.regalloc import RegallocEvictionConfig, regalloc_network


class ConfigTest(absltest.TestCase):

  # TODO(abenalaast): Issue #280
  def test_inlining_config(self):
    problem_config = registry.get_configuration(implementation=InliningConfig)
    time_step_spec, action_spec = problem_config.get_signature_spec()
    quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab')
    creator = inlining_config.get_observation_processing_layer_creator(
        quantile_file_dir=quantile_file_dir,
        with_sqrt=False,
        with_z_score_normalization=False)
    layers = tf.nest.map_structure(creator, time_step_spec.observation)

    actor_network = actor_distribution_network.ActorDistributionNetwork(
        input_tensor_spec=time_step_spec.observation,
        output_tensor_spec=action_spec,
        preprocessing_layers=layers,
        preprocessing_combiner=tf.keras.layers.Concatenate(),
        fc_layer_params=(64, 64, 64, 64),
        dropout_layer_params=None,
        activation_fn=tf.keras.activations.relu)

    policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_network)

    self.assertIsNotNone(policy)
    self.assertIsInstance(
        policy._actor_network,  # pylint: disable=protected-access
        actor_distribution_network.ActorDistributionNetwork)

  # TODO(abenalaast): Issue #280
  def test_regalloc_config(self):
    problem_config = registry.get_configuration(
        implementation=RegallocEvictionConfig)
    time_step_spec, action_spec = problem_config.get_signature_spec()
    quantile_file_dir = os.path.join('compiler_opt', 'rl', 'regalloc', 'vocab')
    creator = regalloc_config.get_observation_processing_layer_creator(
        quantile_file_dir=quantile_file_dir,
        with_sqrt=False,
        with_z_score_normalization=False)
    layers = tf.nest.map_structure(creator, time_step_spec.observation)

    actor_network = regalloc_network.RegAllocNetwork(
        input_tensor_spec=time_step_spec.observation,
        output_tensor_spec=action_spec,
        preprocessing_layers=layers,
        preprocessing_combiner=tf.keras.layers.Concatenate(),
        fc_layer_params=(64, 64, 64, 64),
        dropout_layer_params=None,
        activation_fn=tf.keras.activations.relu)

    policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_network)

    self.assertIsNotNone(policy)
    self.assertIsInstance(
        policy._actor_network,  # pylint: disable=protected-access
        regalloc_network.RegAllocNetwork)


class VectorTest(absltest.TestCase):

  expected_variable_shapes = [(71, 64), (64), (64, 64), (64), (64, 64), (64),
                              (64, 64), (64), (64, 2), (2)]
  expected_length_of_a_perturbation = sum(
      np.prod(shape) for shape in expected_variable_shapes)
  params = np.arange(expected_length_of_a_perturbation, dtype=np.float32)
  POLICY_NAME = 'test_policy_name'

  # TODO(abenalaast): Issue #280
  def test_set_vectorized_parameters_for_policy(self):
    # create a policy
    problem_config = registry.get_configuration(implementation=InliningConfig)
    time_step_spec, action_spec = problem_config.get_signature_spec()
    quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab')
    creator = inlining_config.get_observation_processing_layer_creator(
        quantile_file_dir=quantile_file_dir,
        with_sqrt=False,
        with_z_score_normalization=False)
    layers = tf.nest.map_structure(creator, time_step_spec.observation)

    actor_network = actor_distribution_network.ActorDistributionNetwork(
        input_tensor_spec=time_step_spec.observation,
        output_tensor_spec=action_spec,
        preprocessing_layers=layers,
        preprocessing_combiner=tf.keras.layers.Concatenate(),
        fc_layer_params=(64, 64, 64, 64),
        dropout_layer_params=None,
        activation_fn=tf.keras.activations.relu)

    policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_network)

    # save the policy
    saver = policy_saver.PolicySaver({VectorTest.POLICY_NAME: policy})
    testing_path = self.create_tempdir()
    policy_save_path = os.path.join(testing_path, 'temp_output', 'policy')
    saver.save(policy_save_path)

    # set the values of the policy variables
    policy_utils.set_vectorized_parameters_for_policy(policy, VectorTest.params)
    # iterate through variables and check their shapes and values
    # deep copy params in order to destructively iterate over values
    expected_values = [*VectorTest.params]
    for i, variable in enumerate(policy.variables()):  # pylint: disable=not-callable
      self.assertEqual(variable.shape, VectorTest.expected_variable_shapes[i])
      variable_values = variable.numpy().flatten()
      np.testing.assert_array_almost_equal(
          expected_values[:len(variable_values)], variable_values)
      expected_values = expected_values[len(variable_values):]
    # all values in the copy should have been removed at this point
    self.assertEmpty(expected_values)

    # get saved model to test a loaded policy
    load_path = os.path.join(policy_save_path, VectorTest.POLICY_NAME)
    sm = tf.saved_model.load(load_path)
    self.assertNotIsInstance(sm, tf_policy.TFPolicy)
    policy_utils.set_vectorized_parameters_for_policy(sm, VectorTest.params)
    # deep copy params in order to destructively iterate over values
    expected_values = [*VectorTest.params]
    for i, variable in enumerate(sm.model_variables):
      self.assertEqual(variable.shape, VectorTest.expected_variable_shapes[i])
      variable_values = variable.numpy().flatten()
      np.testing.assert_array_almost_equal(
          expected_values[:len(variable_values)], variable_values)
      expected_values = expected_values[len(variable_values):]
    # all values in the copy should have been removed at this point
    self.assertEmpty(expected_values)

  # TODO(abenalaast): Issue #280
  def test_get_vectorized_parameters_from_policy(self):
    # create a policy
    problem_config = registry.get_configuration(implementation=InliningConfig)
    time_step_spec, action_spec = problem_config.get_signature_spec()
    quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab')
    creator = inlining_config.get_observation_processing_layer_creator(
        quantile_file_dir=quantile_file_dir,
        with_sqrt=False,
        with_z_score_normalization=False)
    layers = tf.nest.map_structure(creator, time_step_spec.observation)

    actor_network = actor_distribution_network.ActorDistributionNetwork(
        input_tensor_spec=time_step_spec.observation,
        output_tensor_spec=action_spec,
        preprocessing_layers=layers,
        preprocessing_combiner=tf.keras.layers.Concatenate(),
        fc_layer_params=(64, 64, 64, 64),
        dropout_layer_params=None,
        activation_fn=tf.keras.activations.relu)

    policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_network)

    # save the policy
    saver = policy_saver.PolicySaver({VectorTest.POLICY_NAME: policy})
    testing_path = self.create_tempdir()
    policy_save_path = os.path.join(testing_path, 'temp_output', 'policy')
    saver.save(policy_save_path)

    # functionality verified in previous test
    policy_utils.set_vectorized_parameters_for_policy(policy, VectorTest.params)
    # vectorize and check if the outcome is the same as the start
    output = policy_utils.get_vectorized_parameters_from_policy(policy)
    np.testing.assert_array_almost_equal(output, VectorTest.params)

    # get saved model to test a loaded policy
    load_path = os.path.join(policy_save_path, VectorTest.POLICY_NAME)
    sm = tf.saved_model.load(load_path)
    self.assertNotIsInstance(sm, tf_policy.TFPolicy)
    policy_utils.set_vectorized_parameters_for_policy(sm, VectorTest.params)
    # vectorize and check if the outcome is the same as the start
    output = policy_utils.get_vectorized_parameters_from_policy(sm)
    np.testing.assert_array_almost_equal(output, VectorTest.params)

  # TODO(abenalaast): Issue #280
  def test_tfpolicy_and_loaded_policy_produce_same_variable_order(self):
    # create a policy
    problem_config = registry.get_configuration(implementation=InliningConfig)
    time_step_spec, action_spec = problem_config.get_signature_spec()
    quantile_file_dir = os.path.join('compiler_opt', 'rl', 'inlining', 'vocab')
    creator = inlining_config.get_observation_processing_layer_creator(
        quantile_file_dir=quantile_file_dir,
        with_sqrt=False,
        with_z_score_normalization=False)
    layers = tf.nest.map_structure(creator, time_step_spec.observation)

    actor_network = actor_distribution_network.ActorDistributionNetwork(
        input_tensor_spec=time_step_spec.observation,
        output_tensor_spec=action_spec,
        preprocessing_layers=layers,
        preprocessing_combiner=tf.keras.layers.Concatenate(),
        fc_layer_params=(64, 64, 64, 64),
        dropout_layer_params=None,
        activation_fn=tf.keras.activations.relu)

    policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_network)

    # save the policy
    saver = policy_saver.PolicySaver({VectorTest.POLICY_NAME: policy})
    testing_path = self.create_tempdir()
    policy_save_path = os.path.join(testing_path, 'temp_output', 'policy')
    saver.save(policy_save_path)

    # set the values of the variables
    policy_utils.set_vectorized_parameters_for_policy(policy, VectorTest.params)
    # save the changes
    saver.save(policy_save_path)
    # vectorize the tfpolicy
    tf_params = policy_utils.get_vectorized_parameters_from_policy(policy)

    # get loaded policy
    load_path = os.path.join(policy_save_path, VectorTest.POLICY_NAME)
    sm = tf.saved_model.load(load_path)
    # vectorize the loaded policy
    loaded_params = policy_utils.get_vectorized_parameters_from_policy(sm)

    # assert that they result in the same order of values
    np.testing.assert_array_almost_equal(tf_params, loaded_params)


if __name__ == '__main__':
  absltest.main()
```
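As a quick sanity check (not part of the commit), the flattened parameter count implied by `expected_variable_shapes` can be computed directly. Note that the bias entries are written as `(64)`, which is a plain int rather than a one-element tuple; `np.prod` treats both the same, so the computed length is unaffected.

```python
# Sanity check of VectorTest.expected_length_of_a_perturbation.
import numpy as np

shapes = [(71, 64), (64,), (64, 64), (64,), (64, 64), (64,),
          (64, 64), (64,), (64, 2), (2,)]
total = sum(int(np.prod(s)) for s in shapes)
print(total)  # 17218 values for the (64, 64, 64, 64) inlining network
```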
