4
4
import argparse
5
5
import os
6
6
import tarfile
7
+ import time
7
8
import tempfile
8
9
import urllib
9
10
import urllib .request
18
19
from tf2onnx .tfonnx import process_tf_graph
19
20
20
21
TMPPATH = tempfile .mkdtemp ()
22
+ PERFITER = 1000
21
23
22
24
23
25
def get_beach (inputs ):
@@ -90,6 +92,9 @@ def __init__(self, url, local, make_input, input_names, output_names,
90
92
self .rtol = rtol
91
93
self .atol = atol
92
94
self .check_only_shape = check_only_shape
95
+ self .perf = None
96
+ self .tf_runtime = 0
97
+ self .onnx_runtime = 0
93
98
94
99
def download_file (self ):
95
100
"""Download file from url."""
@@ -131,6 +136,11 @@ def run_tensorflow(self, sess, inputs):
131
136
k = sess .graph .get_tensor_by_name (k )
132
137
feed_dict [k ] = v
133
138
result = sess .run (self .output_names , feed_dict = feed_dict )
139
+ if self .perf :
140
+ start = time .time ()
141
+ for _ in range (PERFITER ):
142
+ _ = sess .run (self .output_names , feed_dict = feed_dict )
143
+ self .tf_runtime = time .time () - start
134
144
return result
135
145
136
146
@staticmethod
@@ -144,6 +154,11 @@ def run_caffe2(self, name, onnx_graph, inputs):
144
154
model_proto = onnx_graph .make_model ("test" , inputs .keys (), self .output_names )
145
155
prepared_backend = caffe2 .python .onnx .backend .prepare (model_proto )
146
156
results = prepared_backend .run (inputs )
157
+ if self .perf :
158
+ start = time .time ()
159
+ for _ in range (PERFITER ):
160
+ _ = prepared_backend .run (inputs )
161
+ self .onnx_runtime = time .time () - start
147
162
return results
148
163
149
164
def run_onnxmsrt (self , name , onnx_graph , inputs ):
@@ -156,6 +171,11 @@ def run_onnxmsrt(self, name, onnx_graph, inputs):
156
171
f .write (model_proto .SerializeToString ())
157
172
m = lotus .ModelExecutor (model_path )
158
173
results = m .run (self .output_names , inputs )
174
+ if self .perf :
175
+ start = time .time ()
176
+ for _ in range (PERFITER ):
177
+ _ = m .run (self .output_names , inputs )
178
+ self .onnx_runtime = time .time () - start
159
179
return results
160
180
161
181
def run_onnxmsrtnext (self , name , onnx_graph , inputs ):
@@ -167,6 +187,11 @@ def run_onnxmsrtnext(self, name, onnx_graph, inputs):
167
187
f .write (model_proto .SerializeToString ())
168
188
m = lotus .InferenceSession (model_path )
169
189
results = m .run (self .output_names , inputs )
190
+ if self .perf :
191
+ start = time .time ()
192
+ for _ in range (PERFITER ):
193
+ _ = m .run (self .output_names , inputs )
194
+ self .onnx_runtime = time .time () - start
170
195
return results
171
196
172
197
def run_onnxcntk (self , name , onnx_graph , inputs ):
@@ -182,6 +207,11 @@ def run_onnxcntk(self, name, onnx_graph, inputs):
182
207
for arg in z .arguments :
183
208
input_args [arg ] = inputs [arg .name ]
184
209
results = z .eval (input_args )
210
+ if self .perf :
211
+ start = time .time ()
212
+ for _ in range (PERFITER ):
213
+ _ = z .eval (input_args )
214
+ self .onnx_runtime = time .time () - start
185
215
return results
186
216
187
217
def create_onnx_file (self , name , onnx_graph , inputs , outdir ):
@@ -192,9 +222,10 @@ def create_onnx_file(self, name, onnx_graph, inputs, outdir):
192
222
f .write (model_proto .SerializeToString ())
193
223
print ("\t created" , model_path )
194
224
195
- def run_test (self , name , backend = "caffe2" , debug = False , onnx_file = None , opset = None ):
225
+ def run_test (self , name , backend = "caffe2" , debug = False , onnx_file = None , opset = None , perf = None ):
196
226
"""Run complete test against backend."""
197
227
print (name )
228
+ self .perf = perf
198
229
if self .url :
199
230
_ , dir_name = self .download_file ()
200
231
model_path = os .path .join (dir_name , self .local )
@@ -270,6 +301,7 @@ def get_args():
270
301
parser .add_argument ("--debug" , help = "debug vlog" , action = "store_true" )
271
302
parser .add_argument ("--list" , help = "list tests" , action = "store_true" )
272
303
parser .add_argument ("--onnx-file" , help = "create onnx file in directory" )
304
+ parser .add_argument ("--perf" , help = "capture performance numbers" )
273
305
parser .add_argument ("--include-disabled" , help = "include disabled tests" , action = "store_true" )
274
306
args = parser .parse_args ()
275
307
return args
@@ -312,7 +344,8 @@ def main():
312
344
continue
313
345
count += 1
314
346
try :
315
- ret = t .run_test (test , backend = args .backend , debug = args .debug , onnx_file = args .onnx_file , opset = args .opset )
347
+ ret = t .run_test (test , backend = args .backend , debug = args .debug , onnx_file = args .onnx_file ,
348
+ opset = args .opset , perf = args .perf )
316
349
except Exception as ex :
317
350
ret = None
318
351
print (ex )
@@ -321,6 +354,13 @@ def main():
321
354
322
355
print ("=== RESULT: {} failed of {}, backend={}" .format (failed , count , args .backend ))
323
356
357
+ if args .perf :
358
+ with open (args .perf , "w" ) as f :
359
+ f .write ("test,tensorflow,onnx\n " )
360
+ for test in test_keys :
361
+ t = tests [test ]
362
+ if t .perf :
363
+ f .write ("{},{},{}\n " .format (test , t .tf_runtime , t .onnx_runtime ))
324
364
325
365
if __name__ == "__main__" :
326
366
main ()
0 commit comments