Skip to content

Commit 1b381dd

Browse files
committed
add scripts to benchmark tfhub
Signed-off-by: xavier dupré <[email protected]>
1 parent cd64e4e commit 1b381dd

File tree

4 files changed

+287
-7
lines changed

4 files changed

+287
-7
lines changed

examples/benchmark/_tools.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
"""
4+
The following code compares the speed of tensorflow against onnxruntime
5+
with a model downloaded from Tensorflow Hub.
6+
"""
7+
import os
8+
import sys
9+
import time
10+
import tarfile
11+
import zipfile
12+
import subprocess
13+
import datetime
14+
import numpy
15+
from tqdm import tqdm
16+
import onnxruntime
17+
18+
19+
def generate_random_images(shape=(1, 100, 100, 3), n=10, dtype=numpy.float32):
20+
imgs = []
21+
for i in range(n):
22+
sh = shape
23+
img = numpy.clip(numpy.abs(numpy.random.randn(*sh)), 0, 1) * 255
24+
img = img.astype(dtype)
25+
imgs.append(img)
26+
return imgs
27+
28+
29+
def measure_time(fct, imgs, n=50, timeout=15):
30+
"""
31+
Runs *n* times the same function taking one parameter
32+
from *imgs*. It stops if the total time overcomes *timeout*.
33+
It also runs once the function before measuring.
34+
"""
35+
# Let's run it once first.
36+
fct(imgs[0])
37+
# The time is measured for n iterations except if the total time
38+
# overcomes timeout.
39+
results = []
40+
times = []
41+
for i in tqdm(range(0, n)):
42+
img = imgs[i % len(imgs)]
43+
begin = time.perf_counter()
44+
result = fct(img)
45+
end = time.perf_counter()
46+
results.append(result)
47+
times.append(end - begin)
48+
if sum(times) > timeout:
49+
break
50+
return results, times
51+
52+
53+
def download_model(url, dest, verbose=True):
54+
"""
55+
Downloads a model from tfhub and unzips it.
56+
The function assumes the format is `.tar.gz`.
57+
"""
58+
if not os.path.exists(dest):
59+
os.makedirs(dest)
60+
fpath = os.path.join(dest, "model.tar.gz")
61+
if not os.path.exists(fpath):
62+
from tf2onnx import utils
63+
if verbose:
64+
print("Download %r." % fpath)
65+
utils.get_url(url, fpath)
66+
tname = os.path.join(dest, "model_path")
67+
if not os.path.exists(tname):
68+
if verbose:
69+
print("Untar %r." % tname)
70+
tar = tarfile.open(fpath)
71+
tar.extractall(tname)
72+
tar.close()
73+
return fpath, tname
74+
75+
76+
def convert_model(model_name, output_path, opset=13, verbose=True):
77+
"""
78+
Converts the downloaded model into ONNX.
79+
"""
80+
if not os.path.exists(output_path):
81+
begin = datetime.datetime.now()
82+
cmdl = ['-m', 'tf2onnx.convert', '--saved-model',
83+
'"%s"' % os.path.abspath(model_name).replace("\\", "/"),
84+
'--output', '"%s"' % os.path.abspath(output_path).replace("\\", "/"),
85+
'--opset', "%d" % opset]
86+
if verbose:
87+
print("cmd: %s" % " ".join(cmdl))
88+
pproc = subprocess.Popen(
89+
cmdl, shell=True, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
90+
executable=sys.executable.replace("pythonw", "python"))
91+
stdoutdata, stderrdata = pproc.communicate()
92+
if verbose:
93+
print('--OUT--')
94+
print(stdoutdata.decode('ascii'))
95+
print('--ERR--')
96+
print(stderrdata.decode('ascii'))
97+
print("Duration %r." % (datetime.datetime.now() - begin))
98+
99+
100+
def check_discrepencies(out1, out2, threshold=1e-3):
101+
"""
102+
Compares two tensors. Raises an exception if it fails.
103+
"""
104+
if out1.dtype != out2.dtype:
105+
raise AssertionError("Type mismatch %r != %r." % (out1.dtype, out2.dtype))
106+
if out1.shape != out2.shape:
107+
raise AssertionError("Shape mismatch %r != %r." % (out1.shape, out2.shape))
108+
diff = numpy.abs(out1.ravel() - out2.ravel()).max()
109+
if diff > threshold:
110+
raise AssertionError("Discrependcies %r > %r." % (diff, threshold))
111+
112+
113+
def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
114+
signature=None):
115+
"""
116+
Runs a simple benchmark.
117+
Goes through every steps (download, convert).
118+
Skips them if already done.
119+
"""
120+
fpath, tname = download_model(url, dest)
121+
if verbose:
122+
print("Created %r, %r." % (fpath, tname))
123+
124+
# Converts the model.
125+
if verbose:
126+
print("Convert model in %r." % dest)
127+
convert_model(tname, onnx_name, opset)
128+
if verbose:
129+
print("Created %r." % onnx_name)
130+
131+
# Benchmarks both models.
132+
ort = onnxruntime.InferenceSession(onnx_name)
133+
134+
if verbose:
135+
print("ONNX inputs:")
136+
for a in ort.get_inputs():
137+
print(" {}: {}, {}".format(a.name, a.type, a.shape))
138+
print("ONNX outputs:")
139+
for a in ort.get_outputs():
140+
print(" {}: {}, {}".format(a.name, a.type, a.shape))
141+
142+
input_name = ort.get_inputs()[0].name
143+
fct_ort = lambda img: ort.run(None, {input_name: img})[0]
144+
results_ort, duration_ort = measure_time(fct_ort, imgs)
145+
if verbose:
146+
print("ORT", len(imgs), duration_ort)
147+
148+
import tensorflow_hub as hub
149+
from tensorflow import convert_to_tensor
150+
model = hub.load(url.split("?")[0])
151+
if signature is not None:
152+
model = model.signatures['serving_default']
153+
imgs_tf = [convert_to_tensor(img) for img in imgs]
154+
results_tf, duration_tf = measure_time(model, imgs_tf)
155+
156+
if verbose:
157+
print("TF", len(imgs), duration_tf)
158+
mean_ort = sum(duration_ort) / len(duration_ort)
159+
mean_tf = sum(duration_tf) / len(duration_tf)
160+
print("ratio ORT=%r / TF=%r = %r" % (mean_ort, mean_tf, mean_ort / mean_tf))
161+
162+
res = model(imgs_tf[0])
163+
if isinstance(res, dict):
164+
if len(res) != 1:
165+
raise NotImplementedError("TF output contains more than one output: %r." % res)
166+
output_name = ort.get_outputs()[0].name
167+
if output_name not in res:
168+
raise AssertionError("Unable to find output %r in %r." % (output_name, list(sorted(res))))
169+
res = res[output_name]
170+
check_discrepencies(fct_ort(imgs[0]), res.numpy(), threshold)
171+
return duration_ort, duration_tf

examples/benchmark/tfhub_esrgan.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import os
3+
import numpy
4+
from _tools import generate_random_images, measure_time, download_model, convert_model, benchmark
5+
6+
url = "https://tfhub.dev/captain-pool/esrgan-tf2/1?tf-hub-format=compressed"
7+
dest = "tf-esrgan-tf2"
8+
name = "esrgan-tf2"
9+
opset = 13
10+
onnx_name = os.path.join(dest, "esrgan-tf2-%d.onnx" % opset)
11+
12+
imgs = generate_random_images()
13+
14+
benchmark(url, dest, onnx_name, opset, imgs)

examples/benchmark/tfhub_thunder.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import os
3+
import numpy
4+
from _tools import generate_random_images, measure_time, download_model, convert_model, benchmark
5+
6+
url = "https://tfhub.dev/google/movenet/singlepose/thunder/3?tf-hub-format=compressed"
7+
dest = "tf-thunder"
8+
name = "thunder"
9+
opset = 13
10+
onnx_name = os.path.join(dest, "esrgan-tf2-%d.onnx" % opset)
11+
12+
imgs = generate_random_images(shape=(1, 256, 256, 3), dtype=numpy.int32)
13+
14+
benchmark(url, dest, onnx_name, opset, imgs,
15+
signature='serving_default')

examples/benchmark_tfmodel_ort.py

Lines changed: 87 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,18 @@
44
The following code compares the speed of tensorflow against onnxruntime
55
with a model downloaded from Tensorflow Hub.
66
"""
7+
import os
8+
import sys
79
import time
10+
import tarfile
11+
import zipfile
12+
import subprocess
13+
import datetime
814
import numpy
915
from tqdm import tqdm
1016
import tensorflow_hub as hub
1117
import onnxruntime as ort
18+
from tf2onnx import utils, convert
1219

1320

1421
def generate_random_images(shape=(100, 100), n=10):
@@ -21,29 +28,102 @@ def generate_random_images(shape=(100, 100), n=10):
2128
return imgs
2229

2330

24-
def measure_time(fct, imgs):
31+
def measure_time(fct, imgs, n=50, timeout=15):
32+
"""
33+
Runs *n* times the same function taking one parameter
34+
from *imgs*. It stops if the total time overcomes *timeout*.
35+
It also runs once the function before measuring.
36+
"""
37+
# Let's run it once first.
38+
fct(imgs[0])
39+
# The time is measured for n iterations except if the total time
40+
# overcomes timeout.
2541
results = []
2642
times = []
27-
for img in tqdm(imgs):
43+
for i in tqdm(range(0, n)):
44+
img = imgs[i % len(imgs)]
2845
begin = time.perf_counter()
2946
result = fct(img)
3047
end = time.perf_counter()
3148
results.append(result)
3249
times.append(end - begin)
50+
if sum(times) > timeout:
51+
break
3352
return results, times
3453

3554

55+
def download_model(url, dest, verbose=True):
56+
"""
57+
Downloads a model from tfhub and unzips it.
58+
The function assumes the format is `.tar.gz`.
59+
"""
60+
if not os.path.exists(dest):
61+
os.makedirs(dest)
62+
fpath = os.path.join(dest, "model.tar.gz")
63+
if not os.path.exists(fpath):
64+
if verbose:
65+
print("Download %r." % fpath)
66+
utils.get_url(url, fpath)
67+
tname = os.path.join(dest, "model_path")
68+
if not os.path.exists(tname):
69+
if verbose:
70+
print("Untar %r." % tname)
71+
tar = tarfile.open(fpath)
72+
tar.extractall(tname)
73+
tar.close()
74+
return fpath, tname
75+
76+
77+
def convert_model(model_name, output_path, opset=13, verbose=True):
78+
"""
79+
Converts the downloaded model into ONNX.
80+
"""
81+
if not os.path.exists(output_path):
82+
begin = datetime.datetime.now()
83+
cmd = [sys.executable.replace("pythonw", "python"),
84+
'-m', 'tf2onnx.convert', '--saved-model',
85+
'"%s"' % model_name.replace("\\", "/"),
86+
'--output', '"%s"' % output_path.replace("\\", "/"),
87+
'--opset', "%d" % opset]
88+
if verbose:
89+
print("cmd: %s" % " ".join(cmd))
90+
pproc = subprocess.Popen(cmdl, shell=True, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
91+
stdoutdata, stderrdata = pproc.communicate()
92+
if verbose:
93+
print('--OUT--')
94+
print(stdoutdata)
95+
print('--ERR--')
96+
print(stderrdata)
97+
print("Duration %r." % (datetime.datetime.now() - begin))
98+
99+
100+
# Downloads the model
101+
url = "https://tfhub.dev/captain-pool/esrgan-tf2/1?tf-hub-format=compressed"
102+
dest = "tf-esrgan-tf2"
103+
name = "esrgan-tf2"
104+
opset = 13
105+
onnx_name = os.path.join(dest, "esrgan-tf2-%d.onnx" % opset)
106+
107+
fpath, tname = download_model(url, dest)
108+
print("Created %r, %r." % (fpath, tname))
109+
110+
# Converts the model.
111+
print("Convert model in %r." % dest)
112+
convert_model(tname, onnx_name, opset)
113+
print("Created %r." % onnx_name)
114+
115+
# Generates random images.
116+
print("Generates images.")
36117
imgs = generate_random_images()
37118

38-
# Download model from https://tfhub.dev/captain-pool/esrgan-tf2/1
39-
# python -m tf2onnx.convert --saved-model esrgan --output "esrgan-tf2.onnx" --opset 12
40-
ort = ort.InferenceSession('esrgan-tf2.onnx')
119+
# Benchmarks both models.
120+
ort = ort.InferenceSession(onnx_name)
41121
fct_ort = lambda img: ort.run(None, {'input_0': img})
42122
results_ort, duration_ort = measure_time(fct_ort, imgs)
43-
print(len(imgs), duration_ort)
123+
print("ORT", len(imgs), duration_ort)
44124

45125
model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
46126
results_tf, duration_tf = measure_time(model, imgs)
47-
print(len(imgs), duration_tf)
127+
print("TF", len(imgs), duration_tf)
48128

49129
print("ratio ORT / TF", sum(duration_ort) / sum(duration_tf))

0 commit comments

Comments
 (0)