Skip to content

Commit 6354047

Browse files
Refactor policy creation in policy_utils_test
This patch refactors out common policy creation functionality for three tests into a helper function to get rid of some trivial code duplication. Seems like this is a bit of a short term fix for #280. Reviewers: mtrofin Reviewed By: mtrofin Pull Request: #446
1 parent 7037b06 commit 6354047

File tree

1 file changed

+10
-63
lines changed

1 file changed

+10
-63
lines changed

compiler_opt/es/policy_utils_test.py

Lines changed: 10 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,8 @@ class VectorTest(absltest.TestCase):
110110
params = np.arange(expected_length_of_a_perturbation, dtype=np.float32)
111111
POLICY_NAME = 'test_policy_name'
112112

113-
# TODO(abenalaast): Issue #280
114-
def test_set_vectorized_parameters_for_policy(self):
115-
# create a policy
113+
def _save_inlining_policy(
114+
self) -> tuple[str, actor_policy.ActorPolicy, policy_saver.PolicySaver]:
116115
problem_config = registry.get_configuration(
117116
implementation=inlining.InliningConfig)
118117
time_step_spec, action_spec = problem_config.get_signature_spec()
@@ -143,6 +142,12 @@ def test_set_vectorized_parameters_for_policy(self):
143142
testing_path = self.create_tempdir()
144143
policy_save_path = os.path.join(testing_path, 'temp_output', 'policy')
145144
saver.save(policy_save_path)
145+
return (policy_save_path, policy, saver)
146+
147+
# TODO(abenalaast): Issue #280
148+
def test_set_vectorized_parameters_for_policy(self):
149+
# create a policy
150+
policy_save_path, policy, _ = self._save_inlining_policy()
146151

147152
# set the values of the policy variables
148153
policy_utils.set_vectorized_parameters_for_policy(policy, VectorTest.params)
@@ -177,36 +182,7 @@ def test_set_vectorized_parameters_for_policy(self):
177182
# TODO(abenalaast): Issue #280
178183
def test_get_vectorized_parameters_from_policy(self):
179184
# create a policy
180-
problem_config = registry.get_configuration(
181-
implementation=inlining.InliningConfig)
182-
time_step_spec, action_spec = problem_config.get_signature_spec()
183-
quantile_file_dir = os.path.join(TEST_PATH_PREFIX, 'compiler_opt', 'rl',
184-
'inlining', 'vocab')
185-
creator = inlining_config.get_observation_processing_layer_creator(
186-
quantile_file_dir=quantile_file_dir,
187-
with_sqrt=False,
188-
with_z_score_normalization=False)
189-
layers = tf.nest.map_structure(creator, time_step_spec.observation)
190-
191-
actor_network = actor_distribution_network.ActorDistributionNetwork(
192-
input_tensor_spec=time_step_spec.observation,
193-
output_tensor_spec=action_spec,
194-
preprocessing_layers=layers,
195-
preprocessing_combiner=tf.keras.layers.Concatenate(),
196-
fc_layer_params=(64, 64, 64, 64),
197-
dropout_layer_params=None,
198-
activation_fn=tf.keras.activations.relu)
199-
200-
policy = actor_policy.ActorPolicy(
201-
time_step_spec=time_step_spec,
202-
action_spec=action_spec,
203-
actor_network=actor_network)
204-
205-
# save the policy
206-
saver = policy_saver.PolicySaver({VectorTest.POLICY_NAME: policy})
207-
testing_path = self.create_tempdir()
208-
policy_save_path = os.path.join(testing_path, 'temp_output', 'policy')
209-
saver.save(policy_save_path)
185+
policy_save_path, policy, _ = self._save_inlining_policy()
210186

211187
# functionality verified in previous test
212188
policy_utils.set_vectorized_parameters_for_policy(policy, VectorTest.params)
@@ -226,36 +202,7 @@ def test_get_vectorized_parameters_from_policy(self):
226202
# TODO(abenalaast): Issue #280
227203
def test_tfpolicy_and_loaded_policy_produce_same_variable_order(self):
228204
# create a policy
229-
problem_config = registry.get_configuration(
230-
implementation=inlining.InliningConfig)
231-
time_step_spec, action_spec = problem_config.get_signature_spec()
232-
quantile_file_dir = os.path.join(TEST_PATH_PREFIX, 'compiler_opt', 'rl',
233-
'inlining', 'vocab')
234-
creator = inlining_config.get_observation_processing_layer_creator(
235-
quantile_file_dir=quantile_file_dir,
236-
with_sqrt=False,
237-
with_z_score_normalization=False)
238-
layers = tf.nest.map_structure(creator, time_step_spec.observation)
239-
240-
actor_network = actor_distribution_network.ActorDistributionNetwork(
241-
input_tensor_spec=time_step_spec.observation,
242-
output_tensor_spec=action_spec,
243-
preprocessing_layers=layers,
244-
preprocessing_combiner=tf.keras.layers.Concatenate(),
245-
fc_layer_params=(64, 64, 64, 64),
246-
dropout_layer_params=None,
247-
activation_fn=tf.keras.activations.relu)
248-
249-
policy = actor_policy.ActorPolicy(
250-
time_step_spec=time_step_spec,
251-
action_spec=action_spec,
252-
actor_network=actor_network)
253-
254-
# save the policy
255-
saver = policy_saver.PolicySaver({VectorTest.POLICY_NAME: policy})
256-
testing_path = self.create_tempdir()
257-
policy_save_path = os.path.join(testing_path, 'temp_output', 'policy')
258-
saver.save(policy_save_path)
205+
policy_save_path, policy, saver = self._save_inlining_policy()
259206

260207
# set the values of the variables
261208
policy_utils.set_vectorized_parameters_for_policy(policy, VectorTest.params)

0 commit comments

Comments
 (0)