Skip to content

Commit d6c3ebf

Browse files
xadupresdpython
andauthored
update script to test tfhub models (#1644)
Signed-off-by: xavier dupré <[email protected]> Co-authored-by: xavier dupré <[email protected]>
1 parent da76eea commit d6c3ebf

File tree

3 files changed

+98
-32
lines changed

3 files changed

+98
-32
lines changed

tests/tfhub/_tools.py

Lines changed: 85 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def check_discrepencies(out1, out2, threshold=1e-3):
213213

214214
def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
215215
signature=None, tag=None, output_name=None, ort_name=None,
216-
optimize=True):
216+
optimize=True, convert_tflite=None):
217217
"""
218218
Runs a simple benchmark.
219219
Goes through every steps (download, convert).
@@ -241,6 +241,16 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
241241
z.extractall(os.path.join(dest, "large_model"))
242242
onnx_name = onnx_name_unzipped
243243

244+
# tflite
245+
if convert_tflite and not os.path.exists(convert_tflite):
246+
import tensorflow as tf
247+
converter = tf.lite.TFLiteConverter.from_saved_model(tname)
248+
print('TFL-i:', converter.inference_input_type)
249+
print('TFL-o:', converter.inference_output_type)
250+
tflite_model = converter.convert()
251+
with open(convert_tflite, 'wb') as f:
252+
f.write(tflite_model)
253+
244254
# Benchmarks both models.
245255
if optimize:
246256
ort = onnxruntime.InferenceSession(onnx_name)
@@ -330,15 +340,19 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
330340
return duration_ort, duration_tf
331341

332342

333-
def benchmark_tflite(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3):
343+
def benchmark_tflite(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
344+
names=None):
334345
"""
335346
Runs a simple benchmark with a tflite model.
336347
Goes through every steps (download, convert).
337348
Skips them if already done.
338349
"""
339-
tname = download_tflite(url, dest)
340-
if verbose:
341-
print("Created %r." % tname)
350+
if url.startswith('http'):
351+
tname = download_tflite(url, dest)
352+
if verbose:
353+
print("Created %r." % tname)
354+
else:
355+
tname = url
342356

343357
# Converts the model.
344358
if verbose:
@@ -349,7 +363,7 @@ def benchmark_tflite(url, dest, onnx_name, opset, imgs, verbose=True, threshold=
349363

350364
# Benchmarks both models.
351365
ort = onnxruntime.InferenceSession(onnx_name)
352-
366+
353367
if verbose:
354368
print("ONNX inputs:")
355369
for a in ort.get_inputs():
@@ -365,34 +379,80 @@ def benchmark_tflite(url, dest, onnx_name, opset, imgs, verbose=True, threshold=
365379
if verbose:
366380
print("ORT", len(imgs), duration_ort)
367381

368-
# tensorflow
369-
import tensorflow_hub as hub
370-
from tensorflow import convert_to_tensor
371-
if isinstance(imgs[0], OrderedDict):
372-
imgs_tf = [
373-
OrderedDict((k, convert_to_tensor(v)) for k, v in img.items())
374-
for img in imgs]
375-
else:
376-
imgs_tf = [convert_to_tensor(img) for img in imgs]
377-
model = hub.load(url.split("?")[0])
378-
if signature is not None:
379-
model = model.signatures['serving_default']
380-
results_tf, duration_tf = measure_time(model, imgs_tf)
382+
# tflite
383+
import tensorflow as tf
384+
interpreter = tf.lite.Interpreter(tname)
385+
#help(interpreter)
386+
input_details = interpreter.get_input_details()
387+
index_in = input_details[0]['index']
388+
output_details = interpreter.get_output_details()
389+
index_out = output_details[0]['index']
390+
interpreter.allocate_tensors()
391+
392+
def call_tflite(inp):
393+
interpreter.set_tensor(index_in, inp)
394+
interpreter.invoke()
395+
scores = interpreter.get_tensor(index_out)
396+
return scores
397+
398+
# check intermediate results
399+
if names is not None:
400+
from skl2onnx.helpers.onnx_helper import select_model_inputs_outputs
401+
import onnx
402+
403+
with open(onnx_name, "rb") as f:
404+
model_onnx = onnx.load(f)
405+
406+
call_tflite(imgs[0])
407+
inputs = {input_name: imgs[0]}
408+
details = interpreter.get_tensor_details()
409+
names_index = {}
410+
for tt in details:
411+
names_index[tt['name']] = (tt['index'], tt['quantization'], tt['quantization_parameters'])
412+
413+
num_results = []
414+
for name_tfl, name_ort in names:
415+
index = names_index[name_tfl]
416+
417+
tfl_value = interpreter.get_tensor(index[0])
418+
419+
new_name = onnx_name + ".%s.onnx" % name_ort.replace(":", "_").replace(";", "_").replace("/", "_")
420+
if not os.path.exists(new_name):
421+
print('[create onnx model for %r, %r.' % (name_tfl, name_ort))
422+
new_model = select_model_inputs_outputs(model_onnx, outputs=[name_ort])
423+
with open(new_name, "wb") as f:
424+
f.write(new_model.SerializeToString())
425+
426+
ort_inter = onnxruntime.InferenceSession(new_name)
427+
result = ort_inter.run(None, inputs)[0]
428+
429+
diff = numpy.abs(tfl_value.ravel().astype(numpy.float64) -
430+
result.ravel().astype(numpy.float64)).max()
431+
num_results.append("diff=%f names=(%r,%r) " % (diff, name_tfl, name_ort))
432+
print("*** diff=%f names=(%r,%r) " % (diff, name_tfl, name_ort))
433+
print(" TFL:", tfl_value.dtype, tfl_value.shape, tfl_value.min(), tfl_value.max())
434+
print(" ORT:", result.dtype, result.shape, result.min(), result.max())
435+
436+
print("\n".join(num_results))
437+
438+
results_tfl, duration_tfl = measure_time(call_tflite, imgs)
381439

382440
if verbose:
383-
print("TF", len(imgs), duration_tf)
441+
print("TFL", len(imgs), duration_tfl)
384442
mean_ort = sum(duration_ort) / len(duration_ort)
385-
mean_tf = sum(duration_tf) / len(duration_tf)
386-
print("ratio ORT=%r / TF=%r = %r" % (mean_ort, mean_tf, mean_ort / mean_tf))
387-
443+
mean_tfl = sum(duration_tfl) / len(duration_tfl)
444+
print("ratio ORT=%r / TF=%r = %r" % (mean_ort, mean_tfl, mean_ort / mean_tfl))
445+
388446
# checks discrepencies
389-
res = model(imgs_tf[0])
447+
res = call_tflite(imgs[0])
448+
res_ort = fct_ort(imgs[0])
390449
if isinstance(res, dict):
391450
if len(res) != 1:
392451
raise NotImplementedError("TF output contains more than one output: %r." % res)
393452
output_name = ort.get_outputs()[0].name
394453
if output_name not in res:
395454
raise AssertionError("Unable to find output %r in %r." % (output_name, list(sorted(res))))
396455
res = res[output_name]
397-
check_discrepencies(fct_ort(imgs[0]), res.numpy(), threshold)
456+
457+
check_discrepencies(res_ort, res, threshold)
398458
return duration_ort, duration_tf

tests/tfhub/tfhub_yamnet_tf.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
# SPDX-License-Identifier: Apache-2.0
22
import os
33
import numpy
4-
from _tools import generate_random_images, benchmark
4+
from _tools import generate_random_images, benchmark, benchmark_tflite
55

66

77
def main(opset=13):
88
url = "https://tfhub.dev/google/yamnet/1?tf-hub-format=compressed"
99
dest = "tf-yamnet-tf"
1010
name = "yamnet"
1111
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
12+
tfl = os.path.join(dest, 'model.tflite')
1213

1314
imgs = generate_random_images(shape=(16000, ), dtype=numpy.float32, scale=0.)
1415

15-
benchmark(url, dest, onnx_name, opset, imgs)
16+
# benchmark(url, dest, onnx_name, opset, imgs, convert_tflite=tfl)
17+
18+
onnx_name = os.path.join(dest, "%s-tfl-%d.onnx" % (name, opset))
19+
benchmark_tflite(tfl, dest, onnx_name, opset, imgs)
1620

1721

1822
if __name__ == "__main__":

tests/tfhub/tfhub_yamnet_tflite.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ def main(opset=13):
1010
name = "yamnet"
1111
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
1212

13-
imgs = generate_random_images(shape=(1, 256, 256, 3), dtype=numpy.int32)
14-
15-
benchmark_tflite(url, dest, onnx_name, opset, imgs)
16-
# WARNING - For now, onnxruntime only support float32 type for Gemm rewriter
17-
# onnxruntime: Could not find an implementation for the node pre_tower/split_prequant:Split(13)
13+
imgs = generate_random_images(shape=(15600, ), dtype=numpy.float32, scale=0.)
14+
15+
benchmark_tflite(url, dest, onnx_name, opset, imgs, names=[
16+
('stft/rfft3', 'FFT_stft/rfft4_reshape__190:0'),
17+
('magnitude_spectrogram', 'ComplexAbsmagnitude_spectrogram__206:0'),
18+
('log_mel_spectrogram', 'log_mel_spectrogram'),
19+
])
1820

1921

2022
if __name__ == "__main__":

0 commit comments

Comments
 (0)