Skip to content

Commit 23ac300

Browse files
authored
Merge pull request #1587 from xadupre/bench2
Add yamnet to the list of tfhub tested models
2 parents d18d3f7 + dff63d1 commit 23ac300

File tree

7 files changed

+195
-33
lines changed

7 files changed

+195
-33
lines changed

tests/tfhub/_tools.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,22 @@ def download_model(url, dest, verbose=True):
7373
return fpath, tname
7474

7575

76+
def download_tflite(url, dest, verbose=True):
77+
"""
78+
Downloads a model from tfhub.
79+
The function assumes the format is `.tflite`.
80+
"""
81+
if not os.path.exists(dest):
82+
os.makedirs(dest)
83+
fpath = os.path.join(dest, "model.tflite")
84+
if not os.path.exists(fpath):
85+
from tf2onnx import utils
86+
if verbose:
87+
print("Download %r." % fpath)
88+
utils.get_url(url, fpath)
89+
return fpath
90+
91+
7692
def convert_model(model_name, output_path, opset=13, verbose=True):
7793
"""
7894
Converts the downloaded model into ONNX.
@@ -97,6 +113,30 @@ def convert_model(model_name, output_path, opset=13, verbose=True):
97113
print("Duration %r." % (datetime.datetime.now() - begin))
98114

99115

116+
def convert_tflite(model_name, output_path, opset=13, verbose=True):
117+
"""
118+
Converts the downloaded model into ONNX.
119+
"""
120+
if not os.path.exists(output_path):
121+
begin = datetime.datetime.now()
122+
cmdl = ['-m', 'tf2onnx.convert', '--tflite',
123+
'"%s"' % os.path.abspath(model_name).replace("\\", "/"),
124+
'--output', '"%s"' % os.path.abspath(output_path).replace("\\", "/"),
125+
'--opset', "%d" % opset]
126+
if verbose:
127+
print("cmd: python %s" % " ".join(cmdl))
128+
pproc = subprocess.Popen(
129+
cmdl, shell=True, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
130+
executable=sys.executable.replace("pythonw", "python"))
131+
stdoutdata, stderrdata = pproc.communicate()
132+
if verbose:
133+
print('--OUT--')
134+
print(stdoutdata.decode('ascii'))
135+
print('--ERR--')
136+
print(stderrdata.decode('ascii'))
137+
print("Duration %r." % (datetime.datetime.now() - begin))
138+
139+
100140
def check_discrepencies(out1, out2, threshold=1e-3):
101141
"""
102142
Compares two tensors. Raises an exception if it fails.
@@ -172,3 +212,66 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
172212
res = res[output_name]
173213
check_discrepencies(fct_ort(imgs[0]), res.numpy(), threshold)
174214
return duration_ort, duration_tf
215+
216+
217+
def benchmark_tflite(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3):
218+
"""
219+
Runs a simple benchmark with a tflite model.
220+
Goes through every steps (download, convert).
221+
Skips them if already done.
222+
"""
223+
tname = download_tflite(url, dest)
224+
if verbose:
225+
print("Created %r." % tname)
226+
227+
# Converts the model.
228+
if verbose:
229+
print("Convert model in %r." % dest)
230+
convert_tflite(tname, onnx_name, opset)
231+
if verbose:
232+
print("Created %r." % onnx_name)
233+
234+
# Benchmarks both models.
235+
ort = onnxruntime.InferenceSession(onnx_name)
236+
237+
if verbose:
238+
print("ONNX inputs:")
239+
for a in ort.get_inputs():
240+
print(" {}: {}, {}".format(a.name, a.type, a.shape))
241+
print("ONNX outputs:")
242+
for a in ort.get_outputs():
243+
print(" {}: {}, {}".format(a.name, a.type, a.shape))
244+
245+
# onnxruntime
246+
input_name = ort.get_inputs()[0].name
247+
fct_ort = lambda img: ort.run(None, {input_name: img})[0]
248+
results_ort, duration_ort = measure_time(fct_ort, imgs)
249+
if verbose:
250+
print("ORT", len(imgs), duration_ort)
251+
252+
# tensorflow
253+
import tensorflow_hub as hub
254+
from tensorflow import convert_to_tensor
255+
model = hub.load(url.split("?")[0])
256+
if signature is not None:
257+
model = model.signatures['serving_default']
258+
imgs_tf = [convert_to_tensor(img) for img in imgs]
259+
results_tf, duration_tf = measure_time(model, imgs_tf)
260+
261+
if verbose:
262+
print("TF", len(imgs), duration_tf)
263+
mean_ort = sum(duration_ort) / len(duration_ort)
264+
mean_tf = sum(duration_tf) / len(duration_tf)
265+
print("ratio ORT=%r / TF=%r = %r" % (mean_ort, mean_tf, mean_ort / mean_tf))
266+
267+
# checks discrepencies
268+
res = model(imgs_tf[0])
269+
if isinstance(res, dict):
270+
if len(res) != 1:
271+
raise NotImplementedError("TF output contains more than one output: %r." % res)
272+
output_name = ort.get_outputs()[0].name
273+
if output_name not in res:
274+
raise AssertionError("Unable to find output %r in %r." % (output_name, list(sorted(res))))
275+
res = res[output_name]
276+
check_discrepencies(fct_ort(imgs[0]), res.numpy(), threshold)
277+
return duration_ort, duration_tf

tests/tfhub/tfhub_albert_en_xlarge.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import os
3+
import numpy
4+
from _tools import generate_random_images, benchmark
5+
6+
7+
def main(opset=13):
8+
url = "https://tfhub.dev/tensorflow/albert_en_xlarge/3?tf-hub-format=compressed"
9+
dest = "tf-albert-en-xlarge"
10+
name = "albert-en-xlarge"
11+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
12+
13+
imgs = generate_random_images(shape=(1, 256, 256, 3), dtype=numpy.int32)
14+
15+
benchmark(url, dest, onnx_name, opset, imgs,
16+
signature='serving_default')
17+
18+
19+
if __name__ == "__main__":
20+
main()

tests/tfhub/tfhub_esrgan.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
# SPDX-License-Identifier: Apache-2.0
22
import os
33
import numpy
4-
from _tools import generate_random_images, measure_time, download_model, convert_model, benchmark
4+
from _tools import generate_random_images, benchmark
55

6-
url = "https://tfhub.dev/captain-pool/esrgan-tf2/1?tf-hub-format=compressed"
7-
dest = "tf-esrgan-tf2"
8-
name = "esrgan-tf2"
9-
opset = 13
10-
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
116

12-
imgs = generate_random_images()
7+
def main(opset=13):
8+
url = "https://tfhub.dev/captain-pool/esrgan-tf2/1?tf-hub-format=compressed"
9+
dest = "tf-esrgan-tf2"
10+
name = "esrgan-tf2"
11+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
1312

14-
benchmark(url, dest, onnx_name, opset, imgs)
13+
imgs = generate_random_images()
14+
15+
benchmark(url, dest, onnx_name, opset, imgs)
16+
17+
18+
if __name__ == "__main__":
19+
main()

tests/tfhub/tfhub_resnet_v2_152.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
# SPDX-License-Identifier: Apache-2.0
22
import os
33
import numpy
4-
from _tools import generate_random_images, measure_time, download_model, convert_model, benchmark
4+
from _tools import generate_random_images, benchmark
55

6-
url = "https://tfhub.dev/google/imagenet/resnet_v2_152/classification/5?tf-hub-format=compressed"
7-
dest = "tf-resnet_v2_152"
8-
name = "resnet_v2_152"
9-
opset = 13
10-
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
116

12-
imgs = generate_random_images(shape=(1, 224, 224, 3))
7+
def main(opset=13):
8+
url = "https://tfhub.dev/google/imagenet/resnet_v2_152/classification/5?tf-hub-format=compressed"
9+
dest = "tf-resnet_v2_152"
10+
name = "resnet_v2_152"
11+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
1312

14-
benchmark(url, dest, onnx_name, opset, imgs)
13+
imgs = generate_random_images(shape=(1, 224, 224, 3))
14+
15+
benchmark(url, dest, onnx_name, opset, imgs)
16+
17+
18+
if __name__ == "__main__":
19+
main()

tests/tfhub/tfhub_spam_detection.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,19 @@
22
import os
33
import random
44
import numpy
5-
from _tools import generate_random_images, measure_time, download_model, convert_model, benchmark
5+
from _tools import generate_random_images, benchmark
66

7-
url = "https://tfhub.dev/tensorflow/tutorials/spam-detection/1?tf-hub-format=compressed"
8-
dest = "tf-spam-detection"
9-
name = "spam-detection"
10-
opset = 13
11-
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
127

13-
imgs = generate_random_images((1, 20), dtype=numpy.int32)
8+
def main(opset=13):
9+
url = "https://tfhub.dev/tensorflow/tutorials/spam-detection/1?tf-hub-format=compressed"
10+
dest = "tf-spam-detection"
11+
name = "spam-detection"
12+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
1413

15-
benchmark(url, dest, onnx_name, opset, imgs)
14+
imgs = generate_random_images((1, 20), dtype=numpy.int32)
15+
16+
benchmark(url, dest, onnx_name, opset, imgs)
17+
18+
19+
if __name__ == "__main__":
20+
main()

tests/tfhub/tfhub_thunder.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
# SPDX-License-Identifier: Apache-2.0
22
import os
33
import numpy
4-
from _tools import generate_random_images, measure_time, download_model, convert_model, benchmark
4+
from _tools import generate_random_images, benchmark
55

6-
url = "https://tfhub.dev/google/movenet/singlepose/thunder/3?tf-hub-format=compressed"
7-
dest = "tf-thunder"
8-
name = "thunder"
9-
opset = 13
10-
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
116

12-
imgs = generate_random_images(shape=(1, 256, 256, 3), dtype=numpy.int32)
7+
def main(opset=13):
8+
url = "https://tfhub.dev/google/movenet/singlepose/thunder/3?tf-hub-format=compressed"
9+
dest = "tf-thunder"
10+
name = "thunder"
11+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
1312

14-
benchmark(url, dest, onnx_name, opset, imgs,
15-
signature='serving_default')
13+
imgs = generate_random_images(shape=(1, 256, 256, 3), dtype=numpy.int32)
14+
15+
benchmark(url, dest, onnx_name, opset, imgs,
16+
signature='serving_default')
17+
18+
19+
if __name__ == "__main__":
20+
main()

tests/tfhub/tfhub_yamnet.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import os
3+
import numpy
4+
from _tools import generate_random_images, benchmark_tflite
5+
6+
7+
def main(opset=13):
8+
url = "https://tfhub.dev/google/coral-model/yamnet/classification/coral/1?coral-format=tflite"
9+
dest = "tf-yamnet"
10+
name = "yamnet"
11+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
12+
13+
imgs = generate_random_images(shape=(1, 256, 256, 3), dtype=numpy.int32)
14+
15+
benchmark_tflite(url, dest, onnx_name, opset, imgs)
16+
17+
18+
if __name__ == "__main__":
19+
main()

0 commit comments

Comments
 (0)