@@ -46,69 +46,77 @@ def create_tflite(action_parameters):
4646 team_uuid = action_parameters ['team_uuid' ]
4747 model_uuid = action_parameters ['model_uuid' ]
4848
49- model_entity = storage .retrieve_model_entity (team_uuid , model_uuid )
50- model_folder = model_entity ['model_folder' ]
51-
52- # The following code is inspired by
53- # https://colab.sandbox.google.com/github/tensorflow/models/blob/master/research/object_detection/colab_tutorials/convert_odt_model_to_TFLite.ipynb
54- # and
55- # https://github.com/tensorflow/models/blob/b3483b3942ab9bddc94fcbc5bd00fc790d1ddfcb/research/object_detection/export_tflite_graph_tf2.py
56-
57- if not blob_storage .tflite_saved_model_exists (model_folder ):
58- # Export TFLite inference graph.
59- pipeline_config_path = blob_storage .get_pipeline_config_path (model_folder )
60- pipeline_config = pipeline_pb2 .TrainEvalPipelineConfig ()
61- with tf .io .gfile .GFile (pipeline_config_path , 'r' ) as f :
62- text_format .Parse (f .read (), pipeline_config )
63- trained_checkpoint_path = model_entity ['trained_checkpoint_path' ]
64- if trained_checkpoint_path == '' :
65- message = 'Error: Trained checkpoint not found for model_uuid=%s.' % model_uuid
66- logging .critical (message )
67- raise exceptions .HttpErrorNotFound (message )
68- trained_checkpoint_dir = trained_checkpoint_path [:trained_checkpoint_path .rindex ('/' )]
69- output_directory = blob_storage .get_tflite_folder_path (model_folder )
70- max_detections = 10 # This matches the default for TFObjectDetector.Parameters.maxNumDetections in the the FTC SDK.
71- export_tflite_graph_lib_tf2 .export_tflite_model (pipeline_config , trained_checkpoint_dir ,
72- output_directory , max_detections , use_regular_nms = False )
73-
74- action .retrigger_if_necessary (action_parameters )
75-
76- if not blob_storage .tflite_quantized_model_exists (model_folder ):
77- # Convert to a quantized tflite model
78- saved_model_path = blob_storage .get_tflite_saved_model_path (model_folder )
79- converter = tf .lite .TFLiteConverter .from_saved_model (saved_model_path )
80- converter .optimizations = [tf .lite .Optimize .DEFAULT ] # DEFAULT means the tflite model will be quantized.
81- tflite_quantized_model = converter .convert ()
82- blob_storage .store_tflite_quantized_model (model_folder , tflite_quantized_model )
83-
84- action .retrigger_if_necessary (action_parameters )
85-
86- if not blob_storage .tflite_label_map_txt_exists (model_folder ):
87- # Create the label map.
88- blob_storage .store_tflite_label_map_txt (model_folder ,
89- '\n ' .join (model_entity ['sorted_label_list' ]))
90-
91- action .retrigger_if_necessary (action_parameters )
92-
93- if not blob_storage .tflite_model_with_metadata_exists (model_folder ):
94- # Add Metadata
95- # Make a temporary directory
96- folder = '/tmp/tflite_creater/%s' % str (uuid .uuid4 ().hex )
97- os .makedirs (folder , exist_ok = True )
98- try :
99- quantized_model_filename = '%s/quantized_model' % folder
100- blob_storage .write_tflite_quantized_model_to_file (model_folder , quantized_model_filename )
101- label_map_txt_filename = '%s/label_map.txt' % folder
102- blob_storage .write_tflite_label_map_txt_to_file (model_folder , label_map_txt_filename )
103- model_with_metadata_filename = '%s/model_with_metadata.tflite' % folder
104-
105- writer = object_detector .MetadataWriter .create_for_inference (
106- writer_utils .load_file (quantized_model_filename ),
107- input_norm_mean = [127.5 ], input_norm_std = [127.5 ],
108- label_file_paths = [label_map_txt_filename ])
109- writer_utils .save_file (writer .populate (), model_with_metadata_filename )
110-
111- blob_storage .store_tflite_model_with_metadata (model_folder , model_with_metadata_filename )
112- finally :
113- # Delete the temporary directory.
114- shutil .rmtree (folder )
49+ try :
50+ model_entity = storage .retrieve_model_entity (team_uuid , model_uuid )
51+ model_folder = model_entity ['model_folder' ]
52+
53+ # The following code is inspired by
54+ # https://colab.sandbox.google.com/github/tensorflow/models/blob/master/research/object_detection/colab_tutorials/convert_odt_model_to_TFLite.ipynb
55+ # and
56+ # https://github.com/tensorflow/models/blob/b3483b3942ab9bddc94fcbc5bd00fc790d1ddfcb/research/object_detection/export_tflite_graph_tf2.py
57+
58+ if not blob_storage .tflite_saved_model_exists (model_folder ):
59+ # Export TFLite inference graph.
60+ pipeline_config_path = blob_storage .get_pipeline_config_path (model_folder )
61+ pipeline_config = pipeline_pb2 .TrainEvalPipelineConfig ()
62+ with tf .io .gfile .GFile (pipeline_config_path , 'r' ) as f :
63+ text_format .Parse (f .read (), pipeline_config )
64+ trained_checkpoint_path = model_entity ['trained_checkpoint_path' ]
65+ if trained_checkpoint_path == '' :
66+ message = 'Error: Trained checkpoint not found for model_uuid=%s.' % model_uuid
67+ logging .critical (message )
68+ raise exceptions .HttpErrorNotFound (message )
69+ trained_checkpoint_dir = trained_checkpoint_path [:trained_checkpoint_path .rindex ('/' )]
70+ output_directory = blob_storage .get_tflite_folder_path (model_folder )
71+ max_detections = 10 # This matches the default for TFObjectDetector.Parameters.maxNumDetections in the the FTC SDK.
72+ export_tflite_graph_lib_tf2 .export_tflite_model (pipeline_config , trained_checkpoint_dir ,
73+ output_directory , max_detections , use_regular_nms = False )
74+
75+ action .retrigger_if_necessary (action_parameters )
76+
77+ if not blob_storage .tflite_quantized_model_exists (model_folder ):
78+ # Convert to a quantized tflite model
79+ saved_model_path = blob_storage .get_tflite_saved_model_path (model_folder )
80+ converter = tf .lite .TFLiteConverter .from_saved_model (saved_model_path )
81+ converter .optimizations = [tf .lite .Optimize .DEFAULT ] # DEFAULT means the tflite model will be quantized.
82+ tflite_quantized_model = converter .convert ()
83+ blob_storage .store_tflite_quantized_model (model_folder , tflite_quantized_model )
84+
85+ action .retrigger_if_necessary (action_parameters )
86+
87+ if not blob_storage .tflite_label_map_txt_exists (model_folder ):
88+ # Create the label map.
89+ blob_storage .store_tflite_label_map_txt (model_folder ,
90+ '\n ' .join (model_entity ['sorted_label_list' ]))
91+
92+ action .retrigger_if_necessary (action_parameters )
93+
94+ if not blob_storage .tflite_model_with_metadata_exists (model_folder ):
95+ # Add Metadata
96+ # Make a temporary directory
97+ folder = '/tmp/tflite_creater/%s' % str (uuid .uuid4 ().hex )
98+ os .makedirs (folder , exist_ok = True )
99+ try :
100+ quantized_model_filename = '%s/quantized_model' % folder
101+ blob_storage .write_tflite_quantized_model_to_file (model_folder , quantized_model_filename )
102+ label_map_txt_filename = '%s/label_map.txt' % folder
103+ blob_storage .write_tflite_label_map_txt_to_file (model_folder , label_map_txt_filename )
104+ model_with_metadata_filename = '%s/model_with_metadata.tflite' % folder
105+
106+ writer = object_detector .MetadataWriter .create_for_inference (
107+ writer_utils .load_file (quantized_model_filename ),
108+ input_norm_mean = [127.5 ], input_norm_std = [127.5 ],
109+ label_file_paths = [label_map_txt_filename ])
110+ writer_utils .save_file (writer .populate (), model_with_metadata_filename )
111+
112+ blob_storage .store_tflite_model_with_metadata (model_folder , model_with_metadata_filename )
113+ finally :
114+ # Delete the temporary directory.
115+ shutil .rmtree (folder )
116+ except :
117+ # Check if the model has been deleted.
118+ team_entity = storage .retrieve_team_entity (team_uuid )
119+ if 'model_uuids_deleted' in team_entity :
120+ if model_uuid in team_entity ['model_uuids_deleted' ]:
121+ return
122+ raise
0 commit comments