@@ -73,6 +73,22 @@ def download_model(url, dest, verbose=True):
73
73
return fpath , tname
74
74
75
75
76
+ def download_tflite (url , dest , verbose = True ):
77
+ """
78
+ Downloads a model from tfhub.
79
+ The function assumes the format is `.tflite`.
80
+ """
81
+ if not os .path .exists (dest ):
82
+ os .makedirs (dest )
83
+ fpath = os .path .join (dest , "model.tflite" )
84
+ if not os .path .exists (fpath ):
85
+ from tf2onnx import utils
86
+ if verbose :
87
+ print ("Download %r." % fpath )
88
+ utils .get_url (url , fpath )
89
+ return fpath
90
+
91
+
76
92
def convert_model (model_name , output_path , opset = 13 , verbose = True ):
77
93
"""
78
94
Converts the downloaded model into ONNX.
@@ -97,6 +113,30 @@ def convert_model(model_name, output_path, opset=13, verbose=True):
97
113
print ("Duration %r." % (datetime .datetime .now () - begin ))
98
114
99
115
116
+ def convert_tflite (model_name , output_path , opset = 13 , verbose = True ):
117
+ """
118
+ Converts the downloaded model into ONNX.
119
+ """
120
+ if not os .path .exists (output_path ):
121
+ begin = datetime .datetime .now ()
122
+ cmdl = ['-m' , 'tf2onnx.convert' , '--tflite' ,
123
+ '"%s"' % os .path .abspath (model_name ).replace ("\\ " , "/" ),
124
+ '--output' , '"%s"' % os .path .abspath (output_path ).replace ("\\ " , "/" ),
125
+ '--opset' , "%d" % opset ]
126
+ if verbose :
127
+ print ("cmd: python %s" % " " .join (cmdl ))
128
+ pproc = subprocess .Popen (
129
+ cmdl , shell = True , stdin = None , stdout = subprocess .PIPE , stderr = subprocess .PIPE ,
130
+ executable = sys .executable .replace ("pythonw" , "python" ))
131
+ stdoutdata , stderrdata = pproc .communicate ()
132
+ if verbose :
133
+ print ('--OUT--' )
134
+ print (stdoutdata .decode ('ascii' ))
135
+ print ('--ERR--' )
136
+ print (stderrdata .decode ('ascii' ))
137
+ print ("Duration %r." % (datetime .datetime .now () - begin ))
138
+
139
+
100
140
def check_discrepencies (out1 , out2 , threshold = 1e-3 ):
101
141
"""
102
142
Compares two tensors. Raises an exception if it fails.
@@ -172,3 +212,66 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
172
212
res = res [output_name ]
173
213
check_discrepencies (fct_ort (imgs [0 ]), res .numpy (), threshold )
174
214
return duration_ort , duration_tf
215
+
216
+
217
+ def benchmark_tflite (url , dest , onnx_name , opset , imgs , verbose = True , threshold = 1e-3 ):
218
+ """
219
+ Runs a simple benchmark with a tflite model.
220
+ Goes through every steps (download, convert).
221
+ Skips them if already done.
222
+ """
223
+ tname = download_tflite (url , dest )
224
+ if verbose :
225
+ print ("Created %r." % tname )
226
+
227
+ # Converts the model.
228
+ if verbose :
229
+ print ("Convert model in %r." % dest )
230
+ convert_tflite (tname , onnx_name , opset )
231
+ if verbose :
232
+ print ("Created %r." % onnx_name )
233
+
234
+ # Benchmarks both models.
235
+ ort = onnxruntime .InferenceSession (onnx_name )
236
+
237
+ if verbose :
238
+ print ("ONNX inputs:" )
239
+ for a in ort .get_inputs ():
240
+ print (" {}: {}, {}" .format (a .name , a .type , a .shape ))
241
+ print ("ONNX outputs:" )
242
+ for a in ort .get_outputs ():
243
+ print (" {}: {}, {}" .format (a .name , a .type , a .shape ))
244
+
245
+ # onnxruntime
246
+ input_name = ort .get_inputs ()[0 ].name
247
+ fct_ort = lambda img : ort .run (None , {input_name : img })[0 ]
248
+ results_ort , duration_ort = measure_time (fct_ort , imgs )
249
+ if verbose :
250
+ print ("ORT" , len (imgs ), duration_ort )
251
+
252
+ # tensorflow
253
+ import tensorflow_hub as hub
254
+ from tensorflow import convert_to_tensor
255
+ model = hub .load (url .split ("?" )[0 ])
256
+ if signature is not None :
257
+ model = model .signatures ['serving_default' ]
258
+ imgs_tf = [convert_to_tensor (img ) for img in imgs ]
259
+ results_tf , duration_tf = measure_time (model , imgs_tf )
260
+
261
+ if verbose :
262
+ print ("TF" , len (imgs ), duration_tf )
263
+ mean_ort = sum (duration_ort ) / len (duration_ort )
264
+ mean_tf = sum (duration_tf ) / len (duration_tf )
265
+ print ("ratio ORT=%r / TF=%r = %r" % (mean_ort , mean_tf , mean_ort / mean_tf ))
266
+
267
+ # checks discrepencies
268
+ res = model (imgs_tf [0 ])
269
+ if isinstance (res , dict ):
270
+ if len (res ) != 1 :
271
+ raise NotImplementedError ("TF output contains more than one output: %r." % res )
272
+ output_name = ort .get_outputs ()[0 ].name
273
+ if output_name not in res :
274
+ raise AssertionError ("Unable to find output %r in %r." % (output_name , list (sorted (res ))))
275
+ res = res [output_name ]
276
+ check_discrepencies (fct_ort (imgs [0 ]), res .numpy (), threshold )
277
+ return duration_ort , duration_tf
0 commit comments