Skip to content

Commit de67f20

Browse files
authored
Merge pull request #1580 from xadupre/bench
Adds scripts to benchmark tfhub models and check discrepencies
2 parents 9590a8d + 45dcb15 commit de67f20

File tree

6 files changed

+321
-8
lines changed

6 files changed

+321
-8
lines changed

examples/benchmark_tfmodel_ort.py

Lines changed: 89 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,17 @@
44
The following code compares the speed of tensorflow against onnxruntime
55
with a model downloaded from Tensorflow Hub.
66
"""
7+
import os
8+
import sys
79
import time
10+
import tarfile
11+
import subprocess
12+
import datetime
813
import numpy
914
from tqdm import tqdm
1015
import tensorflow_hub as hub
1116
import onnxruntime as ort
17+
from tf2onnx import utils
1218

1319

1420
def generate_random_images(shape=(100, 100), n=10):
@@ -21,29 +27,104 @@ def generate_random_images(shape=(100, 100), n=10):
2127
return imgs
2228

2329

24-
def measure_time(fct, imgs):
30+
def measure_time(fct, imgs, n=50, timeout=15):
31+
"""
32+
Runs *n* times the same function taking one parameter
33+
from *imgs*. It stops if the total time overcomes *timeout*.
34+
It also runs once the function before measuring.
35+
"""
36+
# Let's run it once first.
37+
fct(imgs[0])
38+
# The time is measured for n iterations except if the total time
39+
# overcomes timeout.
2540
results = []
2641
times = []
27-
for img in tqdm(imgs):
42+
for i in tqdm(range(0, n)):
43+
img = imgs[i % len(imgs)]
2844
begin = time.perf_counter()
2945
result = fct(img)
3046
end = time.perf_counter()
3147
results.append(result)
3248
times.append(end - begin)
49+
if sum(times) > timeout:
50+
break
3351
return results, times
3452

3553

54+
def download_model(url, dest, verbose=True):
55+
"""
56+
Downloads a model from tfhub and unzips it.
57+
The function assumes the format is `.tar.gz`.
58+
"""
59+
if not os.path.exists(dest):
60+
os.makedirs(dest)
61+
fpath = os.path.join(dest, "model.tar.gz")
62+
if not os.path.exists(fpath):
63+
if verbose:
64+
print("Download %r." % fpath)
65+
utils.get_url(url, fpath)
66+
tname = os.path.join(dest, "model_path")
67+
if not os.path.exists(tname):
68+
if verbose:
69+
print("Untar %r." % tname)
70+
tar = tarfile.open(fpath)
71+
tar.extractall(tname)
72+
tar.close()
73+
return fpath, tname
74+
75+
76+
def convert_model(model_name, output_path, opset=13, verbose=True):
77+
"""
78+
Converts the downloaded model into ONNX.
79+
"""
80+
if not os.path.exists(output_path):
81+
begin = datetime.datetime.now()
82+
cmdl = ['-m', 'tf2onnx.convert', '--saved-model',
83+
'"%s"' % model_name.replace("\\", "/"),
84+
'--output', '"%s"' % output_path.replace("\\", "/"),
85+
'--opset', "%d" % opset]
86+
if verbose:
87+
print("cmd: python %s" % " ".join(cmdl))
88+
pproc = subprocess.Popen(cmdl, shell=True, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
89+
executable=sys.executable.replace("pythonw", "python"))
90+
stdoutdata, stderrdata = pproc.communicate()
91+
if verbose:
92+
print('--OUT--')
93+
print(stdoutdata)
94+
print('--ERR--')
95+
print(stderrdata)
96+
print("Duration %r." % (datetime.datetime.now() - begin))
97+
98+
99+
# Downloads the model
100+
url = "https://tfhub.dev/captain-pool/esrgan-tf2/1?tf-hub-format=compressed"
101+
dest = os.path.abspath("tf-esrgan-tf2")
102+
name = "esrgan-tf2"
103+
opset = 13
104+
onnx_name = os.path.join(dest, "esrgan-tf2-%d.onnx" % opset)
105+
106+
fpath, tname = download_model(url, dest)
107+
print("Created %r, %r." % (fpath, tname))
108+
109+
# Converts the model.
110+
print("Convert model in %r." % dest)
111+
convert_model(tname, onnx_name, opset)
112+
print("Created %r." % onnx_name)
113+
114+
# Generates random images.
115+
print("Generates images.")
36116
imgs = generate_random_images()
37117

38-
# Download model from https://tfhub.dev/captain-pool/esrgan-tf2/1
39-
# python -m tf2onnx.convert --saved-model esrgan --output "esrgan-tf2.onnx" --opset 12
40-
ort = ort.InferenceSession('esrgan-tf2.onnx')
118+
# Benchmarks both models.
119+
ort = ort.InferenceSession(onnx_name)
41120
fct_ort = lambda img: ort.run(None, {'input_0': img})
42121
results_ort, duration_ort = measure_time(fct_ort, imgs)
43-
print(len(imgs), duration_ort)
122+
print("ORT", len(imgs), duration_ort)
44123

45124
model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
46125
results_tf, duration_tf = measure_time(model, imgs)
47-
print(len(imgs), duration_tf)
126+
print("TF", len(imgs), duration_tf)
48127

49-
print("ratio ORT / TF", sum(duration_ort) / sum(duration_tf))
128+
mean_ort = sum(duration_ort) / len(duration_ort)
129+
mean_tf = sum(duration_tf) / len(duration_tf)
130+
print("ratio ORT=%r / TF=%r = %r" % (mean_ort, mean_tf, mean_ort / mean_tf))

tests/tfhub/_tools.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
"""
4+
The following code compares the speed of tensorflow against onnxruntime
5+
with a model downloaded from Tensorflow Hub.
6+
"""
7+
import os
8+
import sys
9+
import time
10+
import tarfile
11+
import zipfile
12+
import subprocess
13+
import datetime
14+
import numpy
15+
from tqdm import tqdm
16+
import onnxruntime
17+
18+
19+
def generate_random_images(shape=(1, 100, 100, 3), n=10, dtype=numpy.float32):
20+
imgs = []
21+
for i in range(n):
22+
sh = shape
23+
img = numpy.clip(numpy.abs(numpy.random.randn(*sh)), 0, 1) * 255
24+
img = img.astype(dtype)
25+
imgs.append(img)
26+
return imgs
27+
28+
29+
def measure_time(fct, imgs, n=50, timeout=15):
30+
"""
31+
Runs *n* times the same function taking one parameter
32+
from *imgs*. It stops if the total time overcomes *timeout*.
33+
It also runs once the function before measuring.
34+
"""
35+
# Let's run it once first.
36+
fct(imgs[0])
37+
# The time is measured for n iterations except if the total time
38+
# overcomes timeout.
39+
results = []
40+
times = []
41+
for i in tqdm(range(0, n)):
42+
img = imgs[i % len(imgs)]
43+
begin = time.perf_counter()
44+
result = fct(img)
45+
end = time.perf_counter()
46+
results.append(result)
47+
times.append(end - begin)
48+
if sum(times) > timeout:
49+
break
50+
return results, times
51+
52+
53+
def download_model(url, dest, verbose=True):
54+
"""
55+
Downloads a model from tfhub and unzips it.
56+
The function assumes the format is `.tar.gz`.
57+
"""
58+
if not os.path.exists(dest):
59+
os.makedirs(dest)
60+
fpath = os.path.join(dest, "model.tar.gz")
61+
if not os.path.exists(fpath):
62+
from tf2onnx import utils
63+
if verbose:
64+
print("Download %r." % fpath)
65+
utils.get_url(url, fpath)
66+
tname = os.path.join(dest, "model_path")
67+
if not os.path.exists(tname):
68+
if verbose:
69+
print("Untar %r." % tname)
70+
tar = tarfile.open(fpath)
71+
tar.extractall(tname)
72+
tar.close()
73+
return fpath, tname
74+
75+
76+
def convert_model(model_name, output_path, opset=13, verbose=True):
77+
"""
78+
Converts the downloaded model into ONNX.
79+
"""
80+
if not os.path.exists(output_path):
81+
begin = datetime.datetime.now()
82+
cmdl = ['-m', 'tf2onnx.convert', '--saved-model',
83+
'"%s"' % os.path.abspath(model_name).replace("\\", "/"),
84+
'--output', '"%s"' % os.path.abspath(output_path).replace("\\", "/"),
85+
'--opset', "%d" % opset]
86+
if verbose:
87+
print("cmd: python %s" % " ".join(cmdl))
88+
pproc = subprocess.Popen(
89+
cmdl, shell=True, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
90+
executable=sys.executable.replace("pythonw", "python"))
91+
stdoutdata, stderrdata = pproc.communicate()
92+
if verbose:
93+
print('--OUT--')
94+
print(stdoutdata.decode('ascii'))
95+
print('--ERR--')
96+
print(stderrdata.decode('ascii'))
97+
print("Duration %r." % (datetime.datetime.now() - begin))
98+
99+
100+
def check_discrepencies(out1, out2, threshold=1e-3):
101+
"""
102+
Compares two tensors. Raises an exception if it fails.
103+
"""
104+
if out1.dtype != out2.dtype:
105+
raise AssertionError("Type mismatch %r != %r." % (out1.dtype, out2.dtype))
106+
if out1.shape != out2.shape:
107+
raise AssertionError("Shape mismatch %r != %r." % (out1.shape, out2.shape))
108+
diff = numpy.abs(out1.ravel() - out2.ravel()).max()
109+
if diff > threshold:
110+
raise AssertionError("Discrependcies %r > %r." % (diff, threshold))
111+
112+
113+
def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
114+
signature=None):
115+
"""
116+
Runs a simple benchmark.
117+
Goes through every steps (download, convert).
118+
Skips them if already done.
119+
"""
120+
fpath, tname = download_model(url, dest)
121+
if verbose:
122+
print("Created %r, %r." % (fpath, tname))
123+
124+
# Converts the model.
125+
if verbose:
126+
print("Convert model in %r." % dest)
127+
convert_model(tname, onnx_name, opset)
128+
if verbose:
129+
print("Created %r." % onnx_name)
130+
131+
# Benchmarks both models.
132+
ort = onnxruntime.InferenceSession(onnx_name)
133+
134+
if verbose:
135+
print("ONNX inputs:")
136+
for a in ort.get_inputs():
137+
print(" {}: {}, {}".format(a.name, a.type, a.shape))
138+
print("ONNX outputs:")
139+
for a in ort.get_outputs():
140+
print(" {}: {}, {}".format(a.name, a.type, a.shape))
141+
142+
# onnxruntime
143+
input_name = ort.get_inputs()[0].name
144+
fct_ort = lambda img: ort.run(None, {input_name: img})[0]
145+
results_ort, duration_ort = measure_time(fct_ort, imgs)
146+
if verbose:
147+
print("ORT", len(imgs), duration_ort)
148+
149+
# tensorflow
150+
import tensorflow_hub as hub
151+
from tensorflow import convert_to_tensor
152+
model = hub.load(url.split("?")[0])
153+
if signature is not None:
154+
model = model.signatures['serving_default']
155+
imgs_tf = [convert_to_tensor(img) for img in imgs]
156+
results_tf, duration_tf = measure_time(model, imgs_tf)
157+
158+
if verbose:
159+
print("TF", len(imgs), duration_tf)
160+
mean_ort = sum(duration_ort) / len(duration_ort)
161+
mean_tf = sum(duration_tf) / len(duration_tf)
162+
print("ratio ORT=%r / TF=%r = %r" % (mean_ort, mean_tf, mean_ort / mean_tf))
163+
164+
# checks discrepencies
165+
res = model(imgs_tf[0])
166+
if isinstance(res, dict):
167+
if len(res) != 1:
168+
raise NotImplementedError("TF output contains more than one output: %r." % res)
169+
output_name = ort.get_outputs()[0].name
170+
if output_name not in res:
171+
raise AssertionError("Unable to find output %r in %r." % (output_name, list(sorted(res))))
172+
res = res[output_name]
173+
check_discrepencies(fct_ort(imgs[0]), res.numpy(), threshold)
174+
return duration_ort, duration_tf

tests/tfhub/tfhub_esrgan.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import os
3+
import numpy
4+
from _tools import generate_random_images, measure_time, download_model, convert_model, benchmark
5+
6+
url = "https://tfhub.dev/captain-pool/esrgan-tf2/1?tf-hub-format=compressed"
7+
dest = "tf-esrgan-tf2"
8+
name = "esrgan-tf2"
9+
opset = 13
10+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
11+
12+
imgs = generate_random_images()
13+
14+
benchmark(url, dest, onnx_name, opset, imgs)

tests/tfhub/tfhub_resnet_v2_152.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import os
3+
import numpy
4+
from _tools import generate_random_images, measure_time, download_model, convert_model, benchmark
5+
6+
url = "https://tfhub.dev/google/imagenet/resnet_v2_152/classification/5?tf-hub-format=compressed"
7+
dest = "tf-resnet_v2_152"
8+
name = "resnet_v2_152"
9+
opset = 13
10+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
11+
12+
imgs = generate_random_images(shape=(1, 224, 224, 3))
13+
14+
benchmark(url, dest, onnx_name, opset, imgs)

tests/tfhub/tfhub_spam_detection.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import os
3+
import random
4+
import numpy
5+
from _tools import generate_random_images, measure_time, download_model, convert_model, benchmark
6+
7+
url = "https://tfhub.dev/tensorflow/tutorials/spam-detection/1?tf-hub-format=compressed"
8+
dest = "tf-spam-detection"
9+
name = "spam-detection"
10+
opset = 13
11+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
12+
13+
imgs = generate_random_images((1, 20), dtype=numpy.int32)
14+
15+
benchmark(url, dest, onnx_name, opset, imgs)

tests/tfhub/tfhub_thunder.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import os
3+
import numpy
4+
from _tools import generate_random_images, measure_time, download_model, convert_model, benchmark
5+
6+
url = "https://tfhub.dev/google/movenet/singlepose/thunder/3?tf-hub-format=compressed"
7+
dest = "tf-thunder"
8+
name = "thunder"
9+
opset = 13
10+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
11+
12+
imgs = generate_random_images(shape=(1, 256, 256, 3), dtype=numpy.int32)
13+
14+
benchmark(url, dest, onnx_name, opset, imgs,
15+
signature='serving_default')

0 commit comments

Comments
 (0)