|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | # ============================================================================== |
15 | | -"""APIs to convert lowered MLIR from PyTorch to TensorFlow and TFLite artifacts.""" |
| 15 | +"""APIs to convert lowered MLIR from PyTorch to TensorFlow artifacts.""" |
16 | 16 |
|
17 | 17 | import re |
18 | | -import tempfile |
19 | 18 |
|
20 | 19 | import tensorflow as tf |
21 | 20 | import torch |
@@ -155,46 +154,3 @@ def mlir_to_tf_function(lowered: export.MlirLowered): |
155 | 154 | _wrap_as_tf_func(lowered, tf_state_dict), |
156 | 155 | input_signature=_make_input_signatures(lowered), |
157 | 156 | ) |
158 | | - |
159 | | - |
160 | | -def mlir_to_flatbuffer(lowered: export.MlirLowered): |
161 | | - """Convert the MLIR lowered to a TFLite flatbuffer binary.""" |
162 | | - tf_state_dict = _build_tf_state_dict(lowered) |
163 | | - signature_names = [tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] |
164 | | - tf_signatures = [_make_input_signatures(lowered)] |
165 | | - tf_functions = [_wrap_as_tf_func(lowered, tf_state_dict)] |
166 | | - |
167 | | - tf_module = tf.Module() |
168 | | - tf_module.f = [] |
169 | | - |
170 | | - for tf_sig, func in zip(tf_signatures, tf_functions): |
171 | | - tf_module.f.append( |
172 | | - tf.function( |
173 | | - func, |
174 | | - input_signature=tf_sig, |
175 | | - ) |
176 | | - ) |
177 | | - |
178 | | - tf_module._variables = list(tf_state_dict.values()) |
179 | | - |
180 | | - tf_concrete_funcs = [ |
181 | | - func.get_concrete_function(*tf_sig) |
182 | | - for func, tf_sig in zip(tf_module.f, tf_signatures) |
183 | | - ] |
184 | | - |
185 | | - # We need to temporarily save since TFLite's from_concrete_functions does not |
186 | | - # allow providing names for each of the concrete functions. |
187 | | - with tempfile.TemporaryDirectory() as temp_dir_path: |
188 | | - tf.saved_model.save( |
189 | | - tf_module, |
190 | | - temp_dir_path, |
191 | | - signatures={ |
192 | | - sig_name: tf_concrete_funcs[idx] |
193 | | - for idx, sig_name in enumerate(signature_names) |
194 | | - }, |
195 | | - ) |
196 | | - |
197 | | - converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path) |
198 | | - tflite_model = converter.convert() |
199 | | - |
200 | | - return tflite_model |
0 commit comments