@@ -685,7 +685,21 @@ def freeze(self):
685685 self .sess , self .sess .graph_def , output_names )
686686 return graphdef
687687
688- def export (self , output_dir ):
688+ def to_tflite (self , saved_model_dir ):
689+ """Convert to tflite."""
690+ input_name = self .signitures ['image_arrays' ].op .name
691+ input_shapes = {input_name : [None , * self .params ['image_size' ], 3 ]}
692+ converter = tf .lite .TFLiteConverter .from_saved_model (
693+ saved_model_dir ,
694+ input_arrays = [input_name ],
695+ input_shapes = input_shapes ,
696+ output_arrays = [self .signitures ['prediction' ].op .name ])
697+ converter .experimental_new_converter = True
698+ supported_ops = [tf .lite .OpsSet .TFLITE_BUILTINS ]
699+ converter .target_spec .supported_ops = supported_ops
700+ return converter .convert ()
701+
702+ def export (self , output_dir , frozen_pb = True , tflite = True ):
689703 """Export a saved model."""
690704 signitures = self .signitures
691705 signature_def_map = {
@@ -709,10 +723,20 @@ def export(self, output_dir):
709723 logging .info ('Model saved at %s' , output_dir )
710724
711725 # also save freeze pb file.
712- graphdef = self .freeze ()
713- pb_path = os .path .join (output_dir , self .model_name + '_frozen.pb' )
714- tf .io .gfile .GFile (pb_path , 'wb' ).write (graphdef .SerializeToString ())
715- logging .info ('Free graph saved at %s' , pb_path )
726+ if frozen_pb :
727+ graphdef = self .freeze ()
728+ pb_path = os .path .join (output_dir , self .model_name + '_frozen.pb' )
729+ tf .io .gfile .GFile (pb_path , 'wb' ).write (graphdef .SerializeToString ())
730+ logging .info ('Free graph saved at %s' , pb_path )
731+
732+ if tflite :
733+ ver = tf .__version__
734+ if ver < '2.2.0-dev20200501' or ('dev' not in ver and ver < '2.2.0-rc4' ):
735+ raise ValueError ('TFLite requires TF 2.2.0rc4 or laterr version.' )
736+ tflite_model = self .to_tflite (output_dir )
737+ tflite_path = os .path .join (output_dir , self .model_name + '.tflite' )
738+ with tf .io .gfile .GFile (tflite_path , 'wb' ) as f :
739+ f .write (tflite_model )
716740
717741
718742class InferenceDriver (object ):
0 commit comments