Skip to content

Commit c892802

Browse files
authored
TFLite: when saving a saved model, also save the tflite model (#120)
Currently just saving it next to the saved model, so we can start iterating using tflite models with llvm. Eventually we can get more efficient.
1 parent 7175852 commit c892802

File tree

2 files changed

+51
-43
lines changed

2 files changed

+51
-43
lines changed

compiler_opt/rl/policy_saver.py

Lines changed: 49 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,47 @@ def _get_non_identity_op(tensor):
6969
return tensor
7070

7171

72+
def convert_saved_model(sm_dir: str, tflite_model_path: str):
73+
"""Convert a saved model to tflite.
74+
75+
Args:
76+
sm_dir: path to the saved model to convert
77+
78+
tflite_model_path: desired output file path. Directory structure will
79+
be created by this function, as needed.
80+
"""
81+
tf.io.gfile.makedirs(os.path.dirname(tflite_model_path))
82+
converter = tf.lite.TFLiteConverter.from_saved_model(sm_dir)
83+
converter.target_spec.supported_ops = [
84+
tf.lite.OpsSet.TFLITE_BUILTINS,
85+
]
86+
converter.allow_custom_ops = True
87+
tfl_model = converter.convert()
88+
with tf.io.gfile.GFile(tflite_model_path, 'wb') as f:
89+
f.write(tfl_model)
90+
91+
92+
def convert_mlgo_model(mlgo_model_dir: str, tflite_model_dir: str):
93+
"""Convert a mlgo saved model to mlgo tflite.
94+
95+
Args:
96+
mlgo_model_dir: path to the mlgo saved model dir. It is expected to contain
97+
the saved model files (i.e. saved_model.pb, the variables dir) and the
98+
output_spec.json file
99+
100+
tflite_model_dir: path to a directory where the tflite model will be placed.
101+
The model will be named model.tflite. Alongside it will be placed a copy
102+
of the output_spec.json file.
103+
"""
104+
tf.io.gfile.makedirs(tflite_model_dir)
105+
convert_saved_model(mlgo_model_dir,
106+
os.path.join(tflite_model_dir, TFLITE_MODEL_NAME))
107+
108+
src_json = os.path.join(mlgo_model_dir, OUTPUT_SIGNATURE)
109+
dest_json = os.path.join(tflite_model_dir, OUTPUT_SIGNATURE)
110+
tf.io.gfile.copy(src_json, dest_json)
111+
112+
72113
class PolicySaver(object):
73114
"""Object that saves policy and model config file required by inference.
74115
@@ -157,46 +198,11 @@ def _write_output_signature(self, saver, path):
157198
def save(self, root_dir: str):
158199
"""Writes policy and model_binding.txt to root_dir/policy_name/."""
159200
for policy_name, (saver, _) in self._policy_saver_dict.items():
160-
self._save_policy(saver, os.path.join(root_dir, policy_name))
161-
self._write_output_signature(saver, os.path.join(root_dir, policy_name))
162-
163-
164-
def convert_saved_model(sm_dir: str, tflite_model_path: str):
165-
"""Convert a saved model to tflite.
166-
167-
Args:
168-
sm_dir: path to the saved model to convert
169-
170-
tflite_model_path: desired output file path. Directory structure will
171-
be created by this function, as needed.
172-
"""
173-
tf.io.gfile.makedirs(os.path.dirname(tflite_model_path))
174-
converter = tf.lite.TFLiteConverter.from_saved_model(sm_dir)
175-
converter.target_spec.supported_ops = [
176-
tf.lite.OpsSet.TFLITE_BUILTINS,
177-
]
178-
tfl_model = converter.convert()
179-
with tf.io.gfile.GFile(tflite_model_path, 'wb') as f:
180-
f.write(tfl_model)
181-
182-
183-
def convert_mlgo_model(mlgo_model_dir: str, tflite_model_dir: str):
184-
"""Convert a mlgo saved model to mlgo tflite.
185-
186-
Args:
187-
mlgo_model_dir: path to the mlgo saved model dir. It is expected to contain
188-
the saved model files (i.e. saved_model.pb, the variables dir) and the
189-
output_spec.json file
190-
191-
tflite_model_dir: path to a directory where the tflite model will be placed.
192-
The model will be named model.tflite. Alongside it will be placed a copy of
193-
the output_spec.json file.
194-
"""
195-
tf.io.gfile.makedirs(tflite_model_dir)
196-
convert_saved_model(mlgo_model_dir,
197-
os.path.join(tflite_model_dir, TFLITE_MODEL_NAME))
198-
199-
json_file = 'output_spec.json'
200-
src_json = os.path.join(mlgo_model_dir, json_file)
201-
dest_json = os.path.join(tflite_model_dir, json_file)
202-
tf.io.gfile.copy(src_json, dest_json)
201+
saved_model_dir = os.path.join(root_dir, policy_name)
202+
self._save_policy(saver, saved_model_dir)
203+
self._write_output_signature(saver, saved_model_dir)
204+
# This is not quite the most efficient way to do this - we save the model
205+
# just to load it again and save it as tflite - but it's the minimum,
206+
# temporary step so we can validate more thoroughly our use of tflite.
207+
convert_saved_model(saved_model_dir,
208+
os.path.join(saved_model_dir, TFLITE_MODEL_NAME))

compiler_opt/rl/policy_saver_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ def test_save_policy(self):
113113
for sub_dir in ['saved_policy', 'saved_collect_policy']:
114114
self.assertTrue(
115115
tf.io.gfile.exists(os.path.join(root_dir, sub_dir, 'saved_model.pb')))
116+
self.assertTrue(
117+
tf.io.gfile.exists(os.path.join(root_dir, sub_dir, 'model.tflite')))
116118
self.assertTrue(
117119
tf.io.gfile.exists(
118120
os.path.join(root_dir, sub_dir,

0 commit comments

Comments
 (0)