@@ -110,9 +110,8 @@ class VectorTest(absltest.TestCase):
110
110
params = np .arange (expected_length_of_a_perturbation , dtype = np .float32 )
111
111
POLICY_NAME = 'test_policy_name'
112
112
113
- # TODO(abenalaast): Issue #280
114
- def test_set_vectorized_parameters_for_policy (self ):
115
- # create a policy
113
+ def _save_inlining_policy (
114
+ self ) -> tuple [str , actor_policy .ActorPolicy , policy_saver .PolicySaver ]:
116
115
problem_config = registry .get_configuration (
117
116
implementation = inlining .InliningConfig )
118
117
time_step_spec , action_spec = problem_config .get_signature_spec ()
@@ -143,6 +142,12 @@ def test_set_vectorized_parameters_for_policy(self):
143
142
testing_path = self .create_tempdir ()
144
143
policy_save_path = os .path .join (testing_path , 'temp_output' , 'policy' )
145
144
saver .save (policy_save_path )
145
+ return (policy_save_path , policy , saver )
146
+
147
+ # TODO(abenalaast): Issue #280
148
+ def test_set_vectorized_parameters_for_policy (self ):
149
+ # create a policy
150
+ policy_save_path , policy , _ = self ._save_inlining_policy ()
146
151
147
152
# set the values of the policy variables
148
153
policy_utils .set_vectorized_parameters_for_policy (policy , VectorTest .params )
@@ -177,36 +182,7 @@ def test_set_vectorized_parameters_for_policy(self):
177
182
# TODO(abenalaast): Issue #280
178
183
def test_get_vectorized_parameters_from_policy (self ):
179
184
# create a policy
180
- problem_config = registry .get_configuration (
181
- implementation = inlining .InliningConfig )
182
- time_step_spec , action_spec = problem_config .get_signature_spec ()
183
- quantile_file_dir = os .path .join (TEST_PATH_PREFIX , 'compiler_opt' , 'rl' ,
184
- 'inlining' , 'vocab' )
185
- creator = inlining_config .get_observation_processing_layer_creator (
186
- quantile_file_dir = quantile_file_dir ,
187
- with_sqrt = False ,
188
- with_z_score_normalization = False )
189
- layers = tf .nest .map_structure (creator , time_step_spec .observation )
190
-
191
- actor_network = actor_distribution_network .ActorDistributionNetwork (
192
- input_tensor_spec = time_step_spec .observation ,
193
- output_tensor_spec = action_spec ,
194
- preprocessing_layers = layers ,
195
- preprocessing_combiner = tf .keras .layers .Concatenate (),
196
- fc_layer_params = (64 , 64 , 64 , 64 ),
197
- dropout_layer_params = None ,
198
- activation_fn = tf .keras .activations .relu )
199
-
200
- policy = actor_policy .ActorPolicy (
201
- time_step_spec = time_step_spec ,
202
- action_spec = action_spec ,
203
- actor_network = actor_network )
204
-
205
- # save the policy
206
- saver = policy_saver .PolicySaver ({VectorTest .POLICY_NAME : policy })
207
- testing_path = self .create_tempdir ()
208
- policy_save_path = os .path .join (testing_path , 'temp_output' , 'policy' )
209
- saver .save (policy_save_path )
185
+ policy_save_path , policy , _ = self ._save_inlining_policy ()
210
186
211
187
# functionality verified in previous test
212
188
policy_utils .set_vectorized_parameters_for_policy (policy , VectorTest .params )
@@ -226,36 +202,7 @@ def test_get_vectorized_parameters_from_policy(self):
226
202
# TODO(abenalaast): Issue #280
227
203
def test_tfpolicy_and_loaded_policy_produce_same_variable_order (self ):
228
204
# create a policy
229
- problem_config = registry .get_configuration (
230
- implementation = inlining .InliningConfig )
231
- time_step_spec , action_spec = problem_config .get_signature_spec ()
232
- quantile_file_dir = os .path .join (TEST_PATH_PREFIX , 'compiler_opt' , 'rl' ,
233
- 'inlining' , 'vocab' )
234
- creator = inlining_config .get_observation_processing_layer_creator (
235
- quantile_file_dir = quantile_file_dir ,
236
- with_sqrt = False ,
237
- with_z_score_normalization = False )
238
- layers = tf .nest .map_structure (creator , time_step_spec .observation )
239
-
240
- actor_network = actor_distribution_network .ActorDistributionNetwork (
241
- input_tensor_spec = time_step_spec .observation ,
242
- output_tensor_spec = action_spec ,
243
- preprocessing_layers = layers ,
244
- preprocessing_combiner = tf .keras .layers .Concatenate (),
245
- fc_layer_params = (64 , 64 , 64 , 64 ),
246
- dropout_layer_params = None ,
247
- activation_fn = tf .keras .activations .relu )
248
-
249
- policy = actor_policy .ActorPolicy (
250
- time_step_spec = time_step_spec ,
251
- action_spec = action_spec ,
252
- actor_network = actor_network )
253
-
254
- # save the policy
255
- saver = policy_saver .PolicySaver ({VectorTest .POLICY_NAME : policy })
256
- testing_path = self .create_tempdir ()
257
- policy_save_path = os .path .join (testing_path , 'temp_output' , 'policy' )
258
- saver .save (policy_save_path )
205
+ policy_save_path , policy , saver = self ._save_inlining_policy ()
259
206
260
207
# set the values of the variables
261
208
policy_utils .set_vectorized_parameters_for_policy (policy , VectorTest .params )
0 commit comments