@@ -213,7 +213,7 @@ def check_discrepencies(out1, out2, threshold=1e-3):
213
213
214
214
def benchmark (url , dest , onnx_name , opset , imgs , verbose = True , threshold = 1e-3 ,
215
215
signature = None , tag = None , output_name = None , ort_name = None ,
216
- optimize = True ):
216
+ optimize = True , convert_tflite = None ):
217
217
"""
218
218
Runs a simple benchmark.
219
219
Goes through every step (download, convert).
@@ -241,6 +241,16 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
241
241
z .extractall (os .path .join (dest , "large_model" ))
242
242
onnx_name = onnx_name_unzipped
243
243
244
+ # tflite
245
+ if convert_tflite and not os .path .exists (convert_tflite ):
246
+ import tensorflow as tf
247
+ converter = tf .lite .TFLiteConverter .from_saved_model (tname )
248
+ print ('TFL-i:' , converter .inference_input_type )
249
+ print ('TFL-o:' , converter .inference_output_type )
250
+ tflite_model = converter .convert ()
251
+ with open (convert_tflite , 'wb' ) as f :
252
+ f .write (tflite_model )
253
+
244
254
# Benchmarks both models.
245
255
if optimize :
246
256
ort = onnxruntime .InferenceSession (onnx_name )
@@ -330,15 +340,19 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
330
340
return duration_ort , duration_tf
331
341
332
342
333
- def benchmark_tflite (url , dest , onnx_name , opset , imgs , verbose = True , threshold = 1e-3 ):
343
+ def benchmark_tflite (url , dest , onnx_name , opset , imgs , verbose = True , threshold = 1e-3 ,
344
+ names = None ):
334
345
"""
335
346
Runs a simple benchmark with a tflite model.
336
347
Goes through every step (download, convert).
337
348
Skips them if already done.
338
349
"""
339
- tname = download_tflite (url , dest )
340
- if verbose :
341
- print ("Created %r." % tname )
350
+ if url .startswith ('http' ):
351
+ tname = download_tflite (url , dest )
352
+ if verbose :
353
+ print ("Created %r." % tname )
354
+ else :
355
+ tname = url
342
356
343
357
# Converts the model.
344
358
if verbose :
@@ -349,7 +363,7 @@ def benchmark_tflite(url, dest, onnx_name, opset, imgs, verbose=True, threshold=
349
363
350
364
# Benchmarks both models.
351
365
ort = onnxruntime .InferenceSession (onnx_name )
352
-
366
+
353
367
if verbose :
354
368
print ("ONNX inputs:" )
355
369
for a in ort .get_inputs ():
@@ -365,34 +379,80 @@ def benchmark_tflite(url, dest, onnx_name, opset, imgs, verbose=True, threshold=
365
379
if verbose :
366
380
print ("ORT" , len (imgs ), duration_ort )
367
381
368
- # tensorflow
369
- import tensorflow_hub as hub
370
- from tensorflow import convert_to_tensor
371
- if isinstance (imgs [0 ], OrderedDict ):
372
- imgs_tf = [
373
- OrderedDict ((k , convert_to_tensor (v )) for k , v in img .items ())
374
- for img in imgs ]
375
- else :
376
- imgs_tf = [convert_to_tensor (img ) for img in imgs ]
377
- model = hub .load (url .split ("?" )[0 ])
378
- if signature is not None :
379
- model = model .signatures ['serving_default' ]
380
- results_tf , duration_tf = measure_time (model , imgs_tf )
382
+ # tflite
383
+ import tensorflow as tf
384
+ interpreter = tf .lite .Interpreter (tname )
385
+ #help(interpreter)
386
+ input_details = interpreter .get_input_details ()
387
+ index_in = input_details [0 ]['index' ]
388
+ output_details = interpreter .get_output_details ()
389
+ index_out = output_details [0 ]['index' ]
390
+ interpreter .allocate_tensors ()
391
+
392
+ def call_tflite (inp ):
393
+ interpreter .set_tensor (index_in , inp )
394
+ interpreter .invoke ()
395
+ scores = interpreter .get_tensor (index_out )
396
+ return scores
397
+
398
+ # check intermediate results
399
+ if names is not None :
400
+ from skl2onnx .helpers .onnx_helper import select_model_inputs_outputs
401
+ import onnx
402
+
403
+ with open (onnx_name , "rb" ) as f :
404
+ model_onnx = onnx .load (f )
405
+
406
+ call_tflite (imgs [0 ])
407
+ inputs = {input_name : imgs [0 ]}
408
+ details = interpreter .get_tensor_details ()
409
+ names_index = {}
410
+ for tt in details :
411
+ names_index [tt ['name' ]] = (tt ['index' ], tt ['quantization' ], tt ['quantization_parameters' ])
412
+
413
+ num_results = []
414
+ for name_tfl , name_ort in names :
415
+ index = names_index [name_tfl ]
416
+
417
+ tfl_value = interpreter .get_tensor (index [0 ])
418
+
419
+ new_name = onnx_name + ".%s.onnx" % name_ort .replace (":" , "_" ).replace (";" , "_" ).replace ("/" , "_" )
420
+ if not os .path .exists (new_name ):
421
+ print ('[create onnx model for %r, %r.' % (name_tfl , name_ort ))
422
+ new_model = select_model_inputs_outputs (model_onnx , outputs = [name_ort ])
423
+ with open (new_name , "wb" ) as f :
424
+ f .write (new_model .SerializeToString ())
425
+
426
+ ort_inter = onnxruntime .InferenceSession (new_name )
427
+ result = ort_inter .run (None , inputs )[0 ]
428
+
429
+ diff = numpy .abs (tfl_value .ravel ().astype (numpy .float64 ) -
430
+ result .ravel ().astype (numpy .float64 )).max ()
431
+ num_results .append ("diff=%f names=(%r,%r) " % (diff , name_tfl , name_ort ))
432
+ print ("*** diff=%f names=(%r,%r) " % (diff , name_tfl , name_ort ))
433
+ print (" TFL:" , tfl_value .dtype , tfl_value .shape , tfl_value .min (), tfl_value .max ())
434
+ print (" ORT:" , result .dtype , result .shape , result .min (), result .max ())
435
+
436
+ print ("\n " .join (num_results ))
437
+
438
+ results_tfl , duration_tfl = measure_time (call_tflite , imgs )
381
439
382
440
if verbose :
383
- print ("TF " , len (imgs ), duration_tf )
441
+ print ("TFL " , len (imgs ), duration_tfl )
384
442
mean_ort = sum (duration_ort ) / len (duration_ort )
385
- mean_tf = sum (duration_tf ) / len (duration_tf )
386
- print ("ratio ORT=%r / TF=%r = %r" % (mean_ort , mean_tf , mean_ort / mean_tf ))
387
-
443
+ mean_tfl = sum (duration_tfl ) / len (duration_tfl )
444
+ print ("ratio ORT=%r / TF=%r = %r" % (mean_ort , mean_tfl , mean_ort / mean_tfl ))
445
+
388
446
# checks discrepencies
389
- res = model (imgs_tf [0 ])
447
+ res = call_tflite (imgs [0 ])
448
+ res_ort = fct_ort (imgs [0 ])
390
449
if isinstance (res , dict ):
391
450
if len (res ) != 1 :
392
451
raise NotImplementedError ("TF output contains more than one output: %r." % res )
393
452
output_name = ort .get_outputs ()[0 ].name
394
453
if output_name not in res :
395
454
raise AssertionError ("Unable to find output %r in %r." % (output_name , list (sorted (res ))))
396
455
res = res [output_name ]
397
- check_discrepencies (fct_ort (imgs [0 ]), res .numpy (), threshold )
456
+
457
+ check_discrepencies (res_ort , res , threshold )
398
458
return duration_ort , duration_tf
0 commit comments