Skip to content

Commit 7deccdb

Browse files
committed
add spam-detection
Signed-off-by: xavier dupré <[email protected]>
1 parent e04e29d commit 7deccdb

File tree

5 files changed

+33
-14
lines changed

5 files changed

+33
-14
lines changed

examples/benchmark/_tools.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def convert_model(model_name, output_path, opset=13, verbose=True):
8484
'--output', '"%s"' % os.path.abspath(output_path).replace("\\", "/"),
8585
'--opset', "%d" % opset]
8686
if verbose:
87-
print("cmd: %s" % " ".join(cmdl))
87+
print("cmd: python %s" % " ".join(cmdl))
8888
pproc = subprocess.Popen(
8989
cmdl, shell=True, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
9090
executable=sys.executable.replace("pythonw", "python"))
@@ -139,12 +139,14 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
139139
for a in ort.get_outputs():
140140
print(" {}: {}, {}".format(a.name, a.type, a.shape))
141141

142+
# onnxruntime
142143
input_name = ort.get_inputs()[0].name
143144
fct_ort = lambda img: ort.run(None, {input_name: img})[0]
144145
results_ort, duration_ort = measure_time(fct_ort, imgs)
145146
if verbose:
146147
print("ORT", len(imgs), duration_ort)
147148

149+
# tensorflow
148150
import tensorflow_hub as hub
149151
from tensorflow import convert_to_tensor
150152
model = hub.load(url.split("?")[0])
@@ -159,6 +161,7 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
159161
mean_tf = sum(duration_tf) / len(duration_tf)
160162
print("ratio ORT=%r / TF=%r = %r" % (mean_ort, mean_tf, mean_ort / mean_tf))
161163

164+
# checks discrepencies
162165
res = model(imgs_tf[0])
163166
if isinstance(res, dict):
164167
if len(res) != 1:

examples/benchmark/tfhub_esrgan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
dest = "tf-esrgan-tf2"
88
name = "esrgan-tf2"
99
opset = 13
10-
onnx_name = os.path.join(dest, "esrgan-tf2-%d.onnx" % opset)
10+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
1111

1212
imgs = generate_random_images()
1313

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import os
3+
import random
4+
import numpy
5+
from _tools import generate_random_images, measure_time, download_model, convert_model, benchmark
6+
7+
url = "https://tfhub.dev/tensorflow/tutorials/spam-detection/1?tf-hub-format=compressed"
8+
dest = "tf-spam-detection"
9+
name = "spam-detection"
10+
opset = 13
11+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
12+
13+
imgs = generate_random_images((1, 20), dtype=numpy.int32)
14+
15+
benchmark(url, dest, onnx_name, opset, imgs)

examples/benchmark/tfhub_thunder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
dest = "tf-thunder"
88
name = "thunder"
99
opset = 13
10-
onnx_name = os.path.join(dest, "esrgan-tf2-%d.onnx" % opset)
10+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
1111

1212
imgs = generate_random_images(shape=(1, 256, 256, 3), dtype=numpy.int32)
1313

examples/benchmark_tfmodel_ort.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@
88
import sys
99
import time
1010
import tarfile
11-
import zipfile
1211
import subprocess
1312
import datetime
1413
import numpy
1514
from tqdm import tqdm
1615
import tensorflow_hub as hub
1716
import onnxruntime as ort
18-
from tf2onnx import utils, convert
17+
from tf2onnx import utils
1918

2019

2120
def generate_random_images(shape=(100, 100), n=10):
@@ -80,14 +79,14 @@ def convert_model(model_name, output_path, opset=13, verbose=True):
8079
"""
8180
if not os.path.exists(output_path):
8281
begin = datetime.datetime.now()
83-
cmd = [sys.executable.replace("pythonw", "python"),
84-
'-m', 'tf2onnx.convert', '--saved-model',
85-
'"%s"' % model_name.replace("\\", "/"),
86-
'--output', '"%s"' % output_path.replace("\\", "/"),
87-
'--opset', "%d" % opset]
82+
cmdl = ['-m', 'tf2onnx.convert', '--saved-model',
83+
'"%s"' % model_name.replace("\\", "/"),
84+
'--output', '"%s"' % output_path.replace("\\", "/"),
85+
'--opset', "%d" % opset]
8886
if verbose:
89-
print("cmd: %s" % " ".join(cmd))
90-
pproc = subprocess.Popen(cmdl, shell=True, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
87+
print("cmd: python %s" % " ".join(cmdl))
88+
pproc = subprocess.Popen(cmdl, shell=True, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
89+
executable=sys.executable.replace("pythonw", "python"))
9190
stdoutdata, stderrdata = pproc.communicate()
9291
if verbose:
9392
print('--OUT--')
@@ -99,7 +98,7 @@ def convert_model(model_name, output_path, opset=13, verbose=True):
9998

10099
# Downloads the model
101100
url = "https://tfhub.dev/captain-pool/esrgan-tf2/1?tf-hub-format=compressed"
102-
dest = "tf-esrgan-tf2"
101+
dest = os.path.abspath("tf-esrgan-tf2")
103102
name = "esrgan-tf2"
104103
opset = 13
105104
onnx_name = os.path.join(dest, "esrgan-tf2-%d.onnx" % opset)
@@ -126,4 +125,6 @@ def convert_model(model_name, output_path, opset=13, verbose=True):
126125
results_tf, duration_tf = measure_time(model, imgs)
127126
print("TF", len(imgs), duration_tf)
128127

129-
print("ratio ORT / TF", sum(duration_ort) / sum(duration_tf))
128+
mean_ort = sum(duration_ort) / len(duration_ort)
129+
mean_tf = sum(duration_tf) / len(duration_tf)
130+
print("ratio ORT=%r / TF=%r = %r" % (mean_ort, mean_tf, mean_ort / mean_tf))

0 commit comments

Comments
 (0)