|
15 | 15 |
|
16 | 16 | from typing import Protocol
|
17 | 17 | from collections.abc import Sequence
|
| 18 | +import os |
18 | 19 |
|
19 | 20 | import gin
|
20 | 21 | import numpy as np
|
@@ -122,3 +123,32 @@ def save_policy(policy: 'tf_policy.TFPolicy | HasModelVariables',
|
122 | 123 | set_vectorized_parameters_for_policy(policy, parameters)
|
123 | 124 | saver = policy_saver.PolicySaver({policy_name: policy})
|
124 | 125 | saver.save(save_folder)
|
| 126 | + |
| 127 | + |
def convert_to_tflite(policy_as_bytes: bytes, scratch_dir: str,
                      base_policy_path: str) -> str:
  """Converts a policy serialized to bytes to TFLite.

  The byte stream is decoded as a flat float32 array, applied to the saved
  model loaded from base_policy_path via set_vectorized_parameters_for_policy,
  and the result is written to `<scratch_dir>/saved_model`. That saved model
  is then handed to policy_saver.convert_mlgo_model, with
  `<scratch_dir>/tflite` as the output location.

  Args:
    policy_as_bytes: An array of model parameters serialized to a byte stream.
    scratch_dir: A temporary directory being used for scratch that the model
      will get saved into.
    base_policy_path: The path to the base TF saved model that is used to
      determine the model architecture.

  Returns:
    The path `<scratch_dir>/tflite` that was passed to
    policy_saver.convert_mlgo_model as the destination directory.
  """
  perturbation = np.frombuffer(policy_as_bytes, dtype=np.float32)

  saved_model = tf.saved_model.load(base_policy_path)
  set_vectorized_parameters_for_policy(saved_model, perturbation)

  saved_model_dir = os.path.join(scratch_dir, 'saved_model')
  tf.saved_model.save(
      saved_model, saved_model_dir, signatures=saved_model.signatures)

  # The output signature file lives next to the base saved model, so it is
  # copied across explicitly to sit alongside the newly written saved model.
  source = os.path.join(base_policy_path, policy_saver.OUTPUT_SIGNATURE)
  destination = os.path.join(saved_model_dir, policy_saver.OUTPUT_SIGNATURE)
  # overwrite=True: tf.io.gfile.copy refuses to replace an existing file by
  # default, which would make a second conversion into the same scratch_dir
  # fail on the file copied by the first one.
  tf.io.gfile.copy(source, destination, overwrite=True)

  # convert to tflite
  tflite_dir = os.path.join(scratch_dir, 'tflite')
  policy_saver.convert_mlgo_model(saved_model_dir, tflite_dir)
  return tflite_dir
0 commit comments