@@ -69,6 +69,47 @@ def _get_non_identity_op(tensor):
69
69
return tensor
70
70
71
71
72
+ def convert_saved_model (sm_dir : str , tflite_model_path : str ):
73
+ """Convert a saved model to tflite.
74
+
75
+ Args:
76
+ sm_dir: path to the saved model to convert
77
+
78
+ tflite_model_path: desired output file path. Directory structure will
79
+ be created by this function, as needed.
80
+ """
81
+ tf .io .gfile .makedirs (os .path .dirname (tflite_model_path ))
82
+ converter = tf .lite .TFLiteConverter .from_saved_model (sm_dir )
83
+ converter .target_spec .supported_ops = [
84
+ tf .lite .OpsSet .TFLITE_BUILTINS ,
85
+ ]
86
+ converter .allow_custom_ops = True
87
+ tfl_model = converter .convert ()
88
+ with tf .io .gfile .GFile (tflite_model_path , 'wb' ) as f :
89
+ f .write (tfl_model )
90
+
91
+
92
+ def convert_mlgo_model (mlgo_model_dir : str , tflite_model_dir : str ):
93
+ """Convert a mlgo saved model to mlgo tflite.
94
+
95
+ Args:
96
+ mlgo_model_dir: path to the mlgo saved model dir. It is expected to contain
97
+ the saved model files (i.e. saved_model.pb, the variables dir) and the
98
+ output_spec.json file
99
+
100
+ tflite_model_dir: path to a directory where the tflite model will be placed.
101
+ The model will be named model.tflite. Alongside it will be placed a copy
102
+ of the output_spec.json file.
103
+ """
104
+ tf .io .gfile .makedirs (tflite_model_dir )
105
+ convert_saved_model (mlgo_model_dir ,
106
+ os .path .join (tflite_model_dir , TFLITE_MODEL_NAME ))
107
+
108
+ src_json = os .path .join (mlgo_model_dir , OUTPUT_SIGNATURE )
109
+ dest_json = os .path .join (tflite_model_dir , OUTPUT_SIGNATURE )
110
+ tf .io .gfile .copy (src_json , dest_json )
111
+
112
+
72
113
class PolicySaver (object ):
73
114
"""Object that saves policy and model config file required by inference.
74
115
@@ -157,46 +198,11 @@ def _write_output_signature(self, saver, path):
157
198
def save (self , root_dir : str ):
158
199
"""Writes policy and model_binding.txt to root_dir/policy_name/."""
159
200
for policy_name , (saver , _ ) in self ._policy_saver_dict .items ():
160
- self ._save_policy (saver , os .path .join (root_dir , policy_name ))
161
- self ._write_output_signature (saver , os .path .join (root_dir , policy_name ))
162
-
163
-
164
- def convert_saved_model (sm_dir : str , tflite_model_path : str ):
165
- """Convert a saved model to tflite.
166
-
167
- Args:
168
- sm_dir: path to the saved model to convert
169
-
170
- tflite_model_path: desired output file path. Directory structure will
171
- be created by this function, as needed.
172
- """
173
- tf .io .gfile .makedirs (os .path .dirname (tflite_model_path ))
174
- converter = tf .lite .TFLiteConverter .from_saved_model (sm_dir )
175
- converter .target_spec .supported_ops = [
176
- tf .lite .OpsSet .TFLITE_BUILTINS ,
177
- ]
178
- tfl_model = converter .convert ()
179
- with tf .io .gfile .GFile (tflite_model_path , 'wb' ) as f :
180
- f .write (tfl_model )
181
-
182
-
183
- def convert_mlgo_model (mlgo_model_dir : str , tflite_model_dir : str ):
184
- """Convert a mlgo saved model to mlgo tflite.
185
-
186
- Args:
187
- mlgo_model_dir: path to the mlgo saved model dir. It is expected to contain
188
- the saved model files (i.e. saved_model.pb, the variables dir) and the
189
- output_spec.json file
190
-
191
- tflite_model_dir: path to a directory where the tflite model will be placed.
192
- The model will be named model.tflite. Alongside it will be placed a copy of
193
- the output_spec.json file.
194
- """
195
- tf .io .gfile .makedirs (tflite_model_dir )
196
- convert_saved_model (mlgo_model_dir ,
197
- os .path .join (tflite_model_dir , TFLITE_MODEL_NAME ))
198
-
199
- json_file = 'output_spec.json'
200
- src_json = os .path .join (mlgo_model_dir , json_file )
201
- dest_json = os .path .join (tflite_model_dir , json_file )
202
- tf .io .gfile .copy (src_json , dest_json )
201
+ saved_model_dir = os .path .join (root_dir , policy_name )
202
+ self ._save_policy (saver , saved_model_dir )
203
+ self ._write_output_signature (saver , saved_model_dir )
204
+ # This is not quite the most efficient way to do this - we save the model
205
+ # just to load it again and save it as tflite - but it's the minimum,
206
+ # temporary step so we can validate more thoroughly our use of tflite.
207
+ convert_saved_model (saved_model_dir ,
208
+ os .path .join (saved_model_dir , TFLITE_MODEL_NAME ))
0 commit comments