Skip to content

Commit 0b06bf2

Browse files
authored
Utility library/script to convert saved models to tflite models. (#95)
Utility library/script to convert saved models to tflite models.
1 parent faa5c90 commit 0b06bf2

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

compiler_opt/rl/policy_saver.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from typing import Dict, Tuple
2525

2626
OUTPUT_SIGNATURE = 'output_spec.json'
27+
TFLITE_MODEL_NAME = 'model.tflite'
2728

2829
_TYPE_CONVERSION_DICT = {
2930
tf.float32: 'float',
@@ -158,3 +159,44 @@ def save(self, root_dir: str):
158159
for policy_name, (saver, _) in self._policy_saver_dict.items():
159160
self._save_policy(saver, os.path.join(root_dir, policy_name))
160161
self._write_output_signature(saver, os.path.join(root_dir, policy_name))
162+
163+
164+
def convert_saved_model(sm_dir: str, tflite_model_path: str):
165+
"""Convert a saved model to tflite.
166+
167+
Args:
168+
sm_dir: path to the saved model to convert
169+
170+
tflite_model_path: desired output file path. Directory structure will
171+
be created by this function, as needed.
172+
"""
173+
tf.io.gfile.makedirs(os.path.dirname(tflite_model_path))
174+
converter = tf.lite.TFLiteConverter.from_saved_model(sm_dir)
175+
converter.target_spec.supported_ops = [
176+
tf.lite.OpsSet.TFLITE_BUILTINS,
177+
]
178+
tfl_model = converter.convert()
179+
with tf.io.gfile.GFile(tflite_model_path, 'wb') as f:
180+
f.write(tfl_model)
181+
182+
183+
def convert_mlgo_model(mlgo_model_dir: str, tflite_model_dir: str):
184+
"""Convert a mlgo saved model to mlgo tflite.
185+
186+
Args:
187+
mlgo_model_dir: path to the mlgo saved model dir. It is expected to contain
188+
the saved model files (i.e. saved_model.pb, the variables dir) and the
189+
output_spec.json file
190+
191+
tflite_model_dir: path to a directory where the tflite model will be placed.
192+
The model will be named model.tflite. Alongside it will be placed a copy of
193+
the output_spec.json file.
194+
"""
195+
tf.io.gfile.makedirs(tflite_model_dir)
196+
convert_saved_model(mlgo_model_dir,
197+
os.path.join(tflite_model_dir, TFLITE_MODEL_NAME))
198+
199+
json_file = 'output_spec.json'
200+
src_json = os.path.join(mlgo_model_dir, json_file)
201+
dest_json = os.path.join(tflite_model_dir, json_file)
202+
tf.io.gfile.copy(src_json, dest_json)

compiler_opt/rl/policy_saver_test.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,55 @@
2727
from compiler_opt.rl import policy_saver
2828

2929

30+
# copied from the llvm regalloc generator
31+
def _gen_test_model(outdir: str):
32+
policy_decision_label = 'index_to_evict'
33+
policy_output_spec = """
34+
[
35+
{
36+
"logging_name": "index_to_evict",
37+
"tensor_spec": {
38+
"name": "StatefulPartitionedCall",
39+
"port": 0,
40+
"type": "int64_t",
41+
"shape": [
42+
1
43+
]
44+
}
45+
}
46+
]
47+
"""
48+
per_register_feature_list = ['mask']
49+
num_registers = 33
50+
51+
def get_input_signature():
52+
"""Returns (time_step_spec, action_spec) for LLVM register allocation."""
53+
inputs = dict(
54+
(key, tf.TensorSpec(dtype=tf.int64, shape=(num_registers), name=key))
55+
for key in per_register_feature_list)
56+
return inputs
57+
58+
module = tf.Module()
59+
# We have to set this useless variable in order for the TF C API to correctly
60+
# intake it
61+
module.var = tf.Variable(0, dtype=tf.int64)
62+
63+
def action(*inputs):
64+
result = tf.math.argmax(
65+
tf.cast(inputs[0]['mask'], tf.int32), axis=-1) + module.var
66+
return {policy_decision_label: result}
67+
68+
module.action = tf.function()(action)
69+
action = {
70+
'action': module.action.get_concrete_function(get_input_signature())
71+
}
72+
tf.saved_model.save(module, outdir, signatures=action)
73+
output_spec_path = os.path.join(outdir, 'output_spec.json')
74+
with tf.io.gfile.GFile(output_spec_path, 'w') as f:
75+
print(f'Writing output spec to {output_spec_path}.')
76+
f.write(policy_output_spec)
77+
78+
3079
class PolicySaverTest(tf.test.TestCase):
3180

3281
def setUp(self):
@@ -80,6 +129,18 @@ def test_save_policy(self):
80129
}
81130
}], json.loads(tf.io.gfile.GFile(output_signature_fn).read()))
82131

132+
def test_tflite_conversion(self):
133+
sm_dir = os.path.join(self.get_temp_dir(), 'saved_model')
134+
tflite_dir = os.path.join(self.get_temp_dir(), 'tflite_model')
135+
_gen_test_model(sm_dir)
136+
policy_saver.convert_mlgo_model(sm_dir, tflite_dir)
137+
self.assertTrue(
138+
tf.io.gfile.exists(
139+
os.path.join(tflite_dir, policy_saver.TFLITE_MODEL_NAME)))
140+
self.assertTrue(
141+
tf.io.gfile.exists(
142+
os.path.join(tflite_dir, policy_saver.OUTPUT_SIGNATURE)))
143+
83144

84145
if __name__ == '__main__':
85146
tf.test.main()

0 commit comments

Comments
 (0)