Skip to content

Commit 079d4c7

Browse files
Add convert_to_tflite to policy_utils
This patch puts the function that converts policies (numpy arrays serialized to bytes) into TFLite models into policy_utils. This is necessary so that we can perform the conversion on the worker rather than performing it serially on the main invocation. Reviewers: mtrofin Reviewed By: mtrofin Pull Request: #447
1 parent 50faa11 commit 079d4c7

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

compiler_opt/es/policy_utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from typing import Protocol
1717
from collections.abc import Sequence
18+
import os
1819

1920
import gin
2021
import numpy as np
@@ -122,3 +123,32 @@ def save_policy(policy: 'tf_policy.TFPolicy | HasModelVariables',
122123
set_vectorized_parameters_for_policy(policy, parameters)
123124
saver = policy_saver.PolicySaver({policy_name: policy})
124125
saver.save(save_folder)
126+
127+
128+
def convert_to_tflite(policy_as_bytes: bytes, scratch_dir: str,
                      base_policy_path: str) -> str:
  """Builds a TFLite model from a byte-serialized parameter vector.

  The bytes are reinterpreted as a flat float32 array, loaded into the
  saved model found at base_policy_path, and the result is written to
  <scratch_dir>/saved_model before being converted into <scratch_dir>/tflite.

  Args:
    policy_as_bytes: An array of model parameters serialized to a byte stream.
      It is decoded as float32 values.
    scratch_dir: A temporary directory used for scratch. Both the
      intermediate saved model and the TFLite output are written beneath it.
    base_policy_path: The path to the base TF saved model that is used to
      determine the model architecture.

  Returns:
    The path of the directory holding the converted TFLite model, i.e. the
    'tflite' subdirectory of scratch_dir.
  """
  parameters = np.frombuffer(policy_as_bytes, dtype=np.float32)

  # Load the base model and overwrite its variables with the new parameters.
  model = tf.saved_model.load(base_policy_path)
  set_vectorized_parameters_for_policy(model, parameters)

  export_dir = os.path.join(scratch_dir, 'saved_model')
  tf.saved_model.save(model, export_dir, signatures=model.signatures)

  # Carry the file named by policy_saver.OUTPUT_SIGNATURE over from the base
  # policy so that it sits next to the freshly exported saved model.
  signature_file = policy_saver.OUTPUT_SIGNATURE
  tf.io.gfile.copy(
      os.path.join(base_policy_path, signature_file),
      os.path.join(export_dir, signature_file))

  # Convert the exported saved model to TFLite.
  tflite_dir = os.path.join(scratch_dir, 'tflite')
  policy_saver.convert_mlgo_model(export_dir, tflite_dir)
  return tflite_dir

compiler_opt/es/policy_utils_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,27 @@ def test_tfpolicy_and_loaded_policy_produce_same_variable_order(self):
220220
# assert that they result in the same order of values
221221
np.testing.assert_array_almost_equal(tf_params, loaded_params)
222222

223+
def test_convert_to_tflite(self):
  """Checks that convert_to_tflite emits its artifacts and keeps the params."""
  policy_save_path, _, _ = self._save_inlining_policy()
  base_model_path = os.path.join(policy_save_path, self.POLICY_NAME)

  serialized_params = self.params.tobytes()

  scratch_dir = self.create_tempdir()
  tflite_dir = policy_utils.convert_to_tflite(serialized_params, scratch_dir,
                                              base_model_path)

  # Both the TFLite model and its output spec must be present.
  for artifact in ('model.tflite', 'output_spec.json'):
    self.assertTrue(os.path.exists(os.path.join(tflite_dir, artifact)))

  # Additionally assert that the saved model that we create as part of the
  # conversion process has the correct parameters.
  intermediate_model = tf.saved_model.load(
      os.path.join(scratch_dir, 'saved_model'))
  restored_params = policy_utils.get_vectorized_parameters_from_policy(
      intermediate_model)
  np.testing.assert_array_almost_equal(self.params, restored_params)
243+
223244

224245
if __name__ == '__main__':
225246
absltest.main()

0 commit comments

Comments
 (0)