11
11
import zipfile
12
12
import subprocess
13
13
import datetime
14
+ from collections import OrderedDict
14
15
import numpy
15
16
from tqdm import tqdm
16
17
import onnxruntime
17
18
18
19
19
def generate_random_images(shape=(1, 100, 100, 3), n=10, dtype=numpy.float32, scale=255):
    """
    Generates *n* random image-like arrays.

    :param shape: shape of every generated array
    :param n: number of arrays to generate
    :param dtype: dtype the arrays are cast to
    :param scale: upper bound of the generated values, values fall in ``[0, scale]``
    :return: list of *n* arrays
    """
    # |randn| clipped to [0, 1] then rescaled so every value lies in [0, scale];
    # a comprehension replaces the former append loop (the loop index and the
    # `sh = shape` alias were unused).
    return [
        (numpy.clip(numpy.abs(numpy.random.randn(*shape)), 0, 1) * scale).astype(dtype)
        for _ in range(n)
    ]
27
28
28
29
30
def generate_text_inputs(n=10, seq_length=128):
    """
    Generates dummy preprocessed text inputs equivalent to what the albert
    preprocessor produces for one short sentence::

        preprocessor = hub.load("http://tfhub.dev/tensorflow/albert_en_preprocess/3")
        encoder = hub.load("https://tfhub.dev/tensorflow/albert_en_xlarge/3")
        sentences = tf.constant(["Hi I'm some text"])
        embedded_inputs = {k: v.numpy() for k, v in preprocessor(sentences).items()}

    :param n: number of input dictionaries to return
    :param seq_length: padded sequence length (128 matches the recorded output)
    :return: list of *n* :class:`OrderedDict` with keys
        ``input_word_ids``, ``input_type_ids``, ``input_mask``,
        every value an int array of shape ``(1, seq_length)``
    """
    # Token ids recorded for "Hi I'm some text" (leading 2 / trailing 3 are
    # presumably [CLS]/[SEP] markers — TODO confirm), right-padded with zeros.
    tokens = [2, 4148, 31, 22, 79, 109, 1854, 3]
    word_ids = numpy.zeros((1, seq_length), dtype=int)
    word_ids[0, :len(tokens)] = tokens
    # Single-segment input: every type id is zero.
    type_ids = numpy.zeros((1, seq_length), dtype=int)
    # Mask is 1 on real tokens, 0 on padding.
    mask = numpy.zeros((1, seq_length), dtype=int)
    mask[0, :len(tokens)] = 1
    one = OrderedDict([
        ('input_word_ids', word_ids),
        ('input_type_ids', type_ids),
        ('input_mask', mask)])
    # Independent copies: the previous version returned the very same dict
    # object *n* times, so mutating one input silently changed all of them.
    return [OrderedDict((k, v.copy()) for k, v in one.items())
            for _ in range(n)]
66
+
67
+
29
68
def measure_time (fct , imgs , n = 50 , timeout = 15 ):
30
69
"""
31
70
Runs *n* times the same function taking one parameter
@@ -89,16 +128,22 @@ def download_tflite(url, dest, verbose=True):
89
128
return fpath
90
129
91
130
92
- def convert_model (model_name , output_path , opset = 13 , verbose = True ):
131
+ def convert_model (model_name , output_path , opset = 13 , tag = None , verbose = True ):
93
132
"""
94
133
Converts the downloaded model into ONNX.
95
134
"""
135
+ ext = os .path .splitext (output_path )[- 1 ]
136
+ large_model = ext == ".zip"
96
137
if not os .path .exists (output_path ):
97
138
begin = datetime .datetime .now ()
98
139
cmdl = ['-m' , 'tf2onnx.convert' , '--saved-model' ,
99
140
'"%s"' % os .path .abspath (model_name ).replace ("\\ " , "/" ),
100
141
'--output' , '"%s"' % os .path .abspath (output_path ).replace ("\\ " , "/" ),
101
142
'--opset' , "%d" % opset ]
143
+ if tag is not None :
144
+ cmdl .append ('--tag="%s"' % tag )
145
+ if large_model :
146
+ cmdl .append ('--large_model' )
102
147
if verbose :
103
148
print ("cmd: python %s" % " " .join (cmdl ))
104
149
pproc = subprocess .Popen (
@@ -151,7 +196,7 @@ def check_discrepencies(out1, out2, threshold=1e-3):
151
196
152
197
153
198
def benchmark (url , dest , onnx_name , opset , imgs , verbose = True , threshold = 1e-3 ,
154
- signature = None ):
199
+ signature = None , tag = None , output_name = None , ort_name = None ):
155
200
"""
156
201
Runs a simple benchmark.
157
202
Goes through every steps (download, convert).
@@ -164,10 +209,21 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
164
209
# Converts the model.
165
210
if verbose :
166
211
print ("Convert model in %r." % dest )
167
- convert_model (tname , onnx_name , opset )
212
+ convert_model (tname , onnx_name , opset , tag = tag )
168
213
if verbose :
169
214
print ("Created %r." % onnx_name )
170
215
216
+ # unzip large_model
217
+ ext = os .path .splitext (onnx_name )[- 1 ]
218
+ if ext == ".zip" :
219
+ onnx_name_unzipped = os .path .join (dest , "large_model" , "__MODEL_PROTO.onnx" )
220
+ if not os .path .exists (onnx_name_unzipped ):
221
+ if verbose :
222
+ print ("Unzip model in %r." % os .path .join (dest , "large_model" ))
223
+ with zipfile .ZipFile (onnx_name , 'r' ) as z :
224
+ z .extractall (os .path .join (dest , "large_model" ))
225
+ onnx_name = onnx_name_unzipped
226
+
171
227
# Benchmarks both models.
172
228
ort = onnxruntime .InferenceSession (onnx_name )
173
229
@@ -180,19 +236,37 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
180
236
print (" {}: {}, {}" .format (a .name , a .type , a .shape ))
181
237
182
238
# onnxruntime
183
- input_name = ort .get_inputs ()[0 ].name
184
- fct_ort = lambda img : ort .run (None , {input_name : img })[0 ]
239
+ if output_name is None or ort_name is None :
240
+ index = 0
241
+ else :
242
+ output_names = [o .name for o in ort .get_outputs ()]
243
+ if output_name in output_names :
244
+ index = output_names .index (output_name )
245
+ elif ort_name in output_names :
246
+ index = output_names .index (ort_name )
247
+ else :
248
+ index = 0
249
+ if isinstance (imgs [0 ], dict ):
250
+ fct_ort = lambda img : ort .run (None , img )[index ]
251
+ else :
252
+ input_name = ort .get_inputs ()[0 ].name
253
+ fct_ort = lambda img : ort .run (None , {input_name : img })[index ]
185
254
results_ort , duration_ort = measure_time (fct_ort , imgs )
186
255
if verbose :
187
256
print ("ORT" , len (imgs ), duration_ort )
188
257
189
258
# tensorflow
190
259
import tensorflow_hub as hub
191
260
from tensorflow import convert_to_tensor
261
+ if isinstance (imgs [0 ], OrderedDict ):
262
+ imgs_tf = [
263
+ OrderedDict ((k , convert_to_tensor (v )) for k , v in img .items ())
264
+ for img in imgs ]
265
+ else :
266
+ imgs_tf = [convert_to_tensor (img ) for img in imgs ]
192
267
model = hub .load (url .split ("?" )[0 ])
193
268
if signature is not None :
194
- model = model .signatures ['serving_default' ]
195
- imgs_tf = [convert_to_tensor (img ) for img in imgs ]
269
+ model = model .signatures [signature ]
196
270
results_tf , duration_tf = measure_time (model , imgs_tf )
197
271
198
272
if verbose :
@@ -204,13 +278,27 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
204
278
# checks discrepencies
205
279
res = model (imgs_tf [0 ])
206
280
if isinstance (res , dict ):
207
- if len (res ) != 1 :
208
- raise NotImplementedError ("TF output contains more than one output: %r." % res )
209
- output_name = ort .get_outputs ()[0 ].name
281
+ if output_name is None :
282
+ if len (res ) != 1 :
283
+ raise NotImplementedError (
284
+ "TF output contains more than one output=%r and output names=%r." % (
285
+ list (res ), [o .name for o in ort .get_outputs ()]))
286
+ else :
287
+ output_name = ort .get_outputs ()[0 ].name
210
288
if output_name not in res :
211
289
raise AssertionError ("Unable to find output %r in %r." % (output_name , list (sorted (res ))))
212
290
res = res [output_name ]
213
- check_discrepencies (fct_ort (imgs [0 ]), res .numpy (), threshold )
291
+ try :
292
+ check_discrepencies (fct_ort (imgs [0 ]), res .numpy (), threshold )
293
+ except AttributeError as e :
294
+ raise AssertionError (
295
+ "Unable to check discrepencies for res=%r." % res ) from e
296
+ except AssertionError as e :
297
+ output_names = [o .name for o in ort .get_outputs ()]
298
+ res = ort .run (None , imgs [0 ])
299
+ for i , r in enumerate (res ):
300
+ print ("ORT %d: %s: %r: %r" % (i , output_names [i ], r .dtype , r .shape ))
301
+ raise e
214
302
return duration_ort , duration_tf
215
303
216
304
@@ -252,10 +340,15 @@ def benchmark_tflite(url, dest, onnx_name, opset, imgs, verbose=True, threshold=
252
340
# tensorflow
253
341
import tensorflow_hub as hub
254
342
from tensorflow import convert_to_tensor
343
+ if isinstance (imgs [0 ], OrderedDict ):
344
+ imgs_tf = [
345
+ OrderedDict ((k , convert_to_tensor (v )) for k , v in img .items ())
346
+ for img in imgs ]
347
+ else :
348
+ imgs_tf = [convert_to_tensor (img ) for img in imgs ]
255
349
model = hub .load (url .split ("?" )[0 ])
256
350
if signature is not None :
257
351
model = model .signatures ['serving_default' ]
258
- imgs_tf = [convert_to_tensor (img ) for img in imgs ]
259
352
results_tf , duration_tf = measure_time (model , imgs_tf )
260
353
261
354
if verbose :
0 commit comments