Skip to content

Commit bc0e33c

Browse files
authored
Merge pull request #51 from onnx/gs/onnx-1.2
add a simple perf capture to run_pretrained_models.py
2 parents bc7f40e + 52acebb commit bc0e33c

File tree

1 file changed

+42
-2
lines changed

1 file changed

+42
-2
lines changed

tests/run_pretrained_models.py

Lines changed: 42 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import argparse
55
import os
66
import tarfile
7+
import time
78
import tempfile
89
import urllib
910
import urllib.request
@@ -18,6 +19,7 @@
1819
from tf2onnx.tfonnx import process_tf_graph
1920

2021
TMPPATH = tempfile.mkdtemp()
22+
PERFITER = 1000
2123

2224

2325
def get_beach(inputs):
@@ -90,6 +92,9 @@ def __init__(self, url, local, make_input, input_names, output_names,
9092
self.rtol = rtol
9193
self.atol = atol
9294
self.check_only_shape = check_only_shape
95+
self.perf = None
96+
self.tf_runtime = 0
97+
self.onnx_runtime = 0
9398

9499
def download_file(self):
95100
"""Download file from url."""
@@ -131,6 +136,11 @@ def run_tensorflow(self, sess, inputs):
131136
k = sess.graph.get_tensor_by_name(k)
132137
feed_dict[k] = v
133138
result = sess.run(self.output_names, feed_dict=feed_dict)
139+
if self.perf:
140+
start = time.time()
141+
for _ in range(PERFITER):
142+
_ = sess.run(self.output_names, feed_dict=feed_dict)
143+
self.tf_runtime = time.time() - start
134144
return result
135145

136146
@staticmethod
@@ -144,6 +154,11 @@ def run_caffe2(self, name, onnx_graph, inputs):
144154
model_proto = onnx_graph.make_model("test", inputs.keys(), self.output_names)
145155
prepared_backend = caffe2.python.onnx.backend.prepare(model_proto)
146156
results = prepared_backend.run(inputs)
157+
if self.perf:
158+
start = time.time()
159+
for _ in range(PERFITER):
160+
_ = prepared_backend.run(inputs)
161+
self.onnx_runtime = time.time() - start
147162
return results
148163

149164
def run_onnxmsrt(self, name, onnx_graph, inputs):
@@ -156,6 +171,11 @@ def run_onnxmsrt(self, name, onnx_graph, inputs):
156171
f.write(model_proto.SerializeToString())
157172
m = lotus.ModelExecutor(model_path)
158173
results = m.run(self.output_names, inputs)
174+
if self.perf:
175+
start = time.time()
176+
for _ in range(PERFITER):
177+
_ = m.run(self.output_names, inputs)
178+
self.onnx_runtime = time.time() - start
159179
return results
160180

161181
def run_onnxmsrtnext(self, name, onnx_graph, inputs):
@@ -167,6 +187,11 @@ def run_onnxmsrtnext(self, name, onnx_graph, inputs):
167187
f.write(model_proto.SerializeToString())
168188
m = lotus.InferenceSession(model_path)
169189
results = m.run(self.output_names, inputs)
190+
if self.perf:
191+
start = time.time()
192+
for _ in range(PERFITER):
193+
_ = m.run(self.output_names, inputs)
194+
self.onnx_runtime = time.time() - start
170195
return results
171196

172197
def run_onnxcntk(self, name, onnx_graph, inputs):
@@ -182,6 +207,11 @@ def run_onnxcntk(self, name, onnx_graph, inputs):
182207
for arg in z.arguments:
183208
input_args[arg] = inputs[arg.name]
184209
results = z.eval(input_args)
210+
if self.perf:
211+
start = time.time()
212+
for _ in range(PERFITER):
213+
_ = z.eval(input_args)
214+
self.onnx_runtime = time.time() - start
185215
return results
186216

187217
def create_onnx_file(self, name, onnx_graph, inputs, outdir):
@@ -192,9 +222,10 @@ def create_onnx_file(self, name, onnx_graph, inputs, outdir):
192222
f.write(model_proto.SerializeToString())
193223
print("\tcreated", model_path)
194224

195-
def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=None):
225+
def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=None, perf=None):
196226
"""Run complete test against backend."""
197227
print(name)
228+
self.perf = perf
198229
if self.url:
199230
_, dir_name = self.download_file()
200231
model_path = os.path.join(dir_name, self.local)
@@ -270,6 +301,7 @@ def get_args():
270301
parser.add_argument("--debug", help="debug vlog", action="store_true")
271302
parser.add_argument("--list", help="list tests", action="store_true")
272303
parser.add_argument("--onnx-file", help="create onnx file in directory")
304+
parser.add_argument("--perf", help="capture performance numbers")
273305
parser.add_argument("--include-disabled", help="include disabled tests", action="store_true")
274306
args = parser.parse_args()
275307
return args
@@ -312,7 +344,8 @@ def main():
312344
continue
313345
count += 1
314346
try:
315-
ret = t.run_test(test, backend=args.backend, debug=args.debug, onnx_file=args.onnx_file, opset=args.opset)
347+
ret = t.run_test(test, backend=args.backend, debug=args.debug, onnx_file=args.onnx_file,
348+
opset=args.opset, perf=args.perf)
316349
except Exception as ex:
317350
ret = None
318351
print(ex)
@@ -321,6 +354,13 @@ def main():
321354

322355
print("=== RESULT: {} failed of {}, backend={}".format(failed, count, args.backend))
323356

357+
if args.perf:
358+
with open(args.perf, "w") as f:
359+
f.write("test,tensorflow,onnx\n")
360+
for test in test_keys:
361+
t = tests[test]
362+
if t.perf:
363+
f.write("{},{},{}\n".format(test, t.tf_runtime, t.onnx_runtime))
324364

325365
if __name__ == "__main__":
326366
main()

0 commit comments

Comments (0)