Skip to content

Commit 0c69c13

Browse files
committed
Export TFLite.
Fix #341 Fix #158 Fix #138 Fix #4
1 parent 3ab4157 commit 0c69c13

File tree

1 file changed

+29
-5
lines changed

1 file changed

+29
-5
lines changed

efficientdet/inference.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,21 @@ def freeze(self):
685685
self.sess, self.sess.graph_def, output_names)
686686
return graphdef
687687

688-
def export(self, output_dir):
688+
def to_tflite(self, saved_model_dir):
689+
"""Convert to tflite."""
690+
input_name = self.signitures['image_arrays'].op.name
691+
input_shapes = {input_name: [None, *self.params['image_size'], 3]}
692+
converter = tf.lite.TFLiteConverter.from_saved_model(
693+
saved_model_dir,
694+
input_arrays=[input_name],
695+
input_shapes=input_shapes,
696+
output_arrays=[self.signitures['prediction'].op.name])
697+
converter.experimental_new_converter = True
698+
supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
699+
converter.target_spec.supported_ops = supported_ops
700+
return converter.convert()
701+
702+
def export(self, output_dir, frozen_pb=True, tflite=True):
689703
"""Export a saved model."""
690704
signitures = self.signitures
691705
signature_def_map = {
@@ -709,10 +723,20 @@ def export(self, output_dir):
709723
logging.info('Model saved at %s', output_dir)
710724

711725
# also save freeze pb file.
712-
graphdef = self.freeze()
713-
pb_path = os.path.join(output_dir, self.model_name + '_frozen.pb')
714-
tf.io.gfile.GFile(pb_path, 'wb').write(graphdef.SerializeToString())
715-
logging.info('Free graph saved at %s', pb_path)
726+
if frozen_pb:
727+
graphdef = self.freeze()
728+
pb_path = os.path.join(output_dir, self.model_name + '_frozen.pb')
729+
tf.io.gfile.GFile(pb_path, 'wb').write(graphdef.SerializeToString())
730+
logging.info('Free graph saved at %s', pb_path)
731+
732+
if tflite:
733+
ver = tf.__version__
734+
if ver < '2.2.0-dev20200501' or ('dev' not in ver and ver < '2.2.0-rc4'):
735+
raise ValueError('TFLite requires TF 2.2.0rc4 or laterr version.')
736+
tflite_model = self.to_tflite(output_dir)
737+
tflite_path = os.path.join(output_dir, self.model_name + '.tflite')
738+
with tf.io.gfile.GFile(tflite_path, 'wb') as f:
739+
f.write(tflite_model)
716740

717741

718742
class InferenceDriver(object):

0 commit comments

Comments
 (0)