10
10
11
11
import argparse
12
12
import os
13
+ import re
13
14
import sys
14
15
import tarfile
15
16
import time
@@ -98,8 +99,8 @@ def __init__(self, url, local, make_input, input_names, output_names,
98
99
self .force_input_shape = force_input_shape
99
100
self .skip_tensorflow = skip_tensorflow
100
101
101
- def download_file (self ):
102
- """Download file from url."""
102
+ def download_model (self ):
103
+ """Download model from url."""
103
104
cache_dir = Test .cache_dir
104
105
if not os .path .exists (cache_dir ):
105
106
os .makedirs (cache_dir )
@@ -193,13 +194,13 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
193
194
194
195
# get the model
195
196
if self .url :
196
- _ , dir_name = self .download_file ()
197
+ _ , dir_name = self .download_model ()
198
+ logger .info ("Downloaded to %s" , dir_name )
197
199
model_path = os .path .join (dir_name , self .local )
198
200
else :
199
201
model_path = self .local
200
- dir_name = os .path .dirname (self .local )
201
- logger .info ("Downloaded to %s" , model_path )
202
202
203
+ logger .info ("Load model from %s" , model_path )
203
204
input_names = list (self .input_names .keys ())
204
205
outputs = self .output_names
205
206
if self .model_type in ["checkpoint" ]:
@@ -324,10 +325,13 @@ def get_args():
324
325
return args
325
326
326
327
327
- def tests_from_yaml (fname ):
328
+ def tests_from_yaml (path ):
328
329
"""Create test class from yaml file."""
330
+ path = os .path .abspath (path )
331
+ base_dir = os .path .dirname (path )
332
+
329
333
tests = {}
330
- config = yaml .load (open (fname , 'r' ).read ())
334
+ config = yaml .load (open (path , 'r' ).read ())
331
335
for k , v in config .items ():
332
336
input_func = v .get ("input_get" )
333
337
input_func = _INPUT_FUNC_MAPPING [input_func ]
@@ -337,7 +341,27 @@ def tests_from_yaml(fname):
337
341
if v .get (kw ) is not None :
338
342
kwargs [kw ] = v [kw ]
339
343
340
- test = Test (v .get ("url" ), v .get ("model" ), input_func , v .get ("inputs" ), v .get ("outputs" ), ** kwargs )
344
+ # when model is local, non-absolute path is relative to yaml directory
345
+ url = v .get ("url" )
346
+ model = v .get ("model" )
347
+ if not url :
348
+ if not os .path .isabs (model ):
349
+ model = os .path .join (base_dir , model )
350
+
351
+ # non-absolute npy file path for np.load is relative to yaml directory
352
+ input_names = v .get ("inputs" )
353
+ for key in list (input_names .keys ()):
354
+ value = input_names [key ]
355
+ if isinstance (value , str ):
356
+ # assume at most 1 match
357
+ matches = re .findall (r"np\.load\((r?['\"].*?['\"])" , value )
358
+ if matches :
359
+ npy_path = matches [0 ].lstrip ('r' ).strip ("'" ).strip ('"' )
360
+ if not os .path .isabs (npy_path ):
361
+ abs_npy_path = os .path .join (base_dir , npy_path )
362
+ input_names [key ] = value .replace (matches [0 ], "r'{}'" .format (abs_npy_path ))
363
+
364
+ test = Test (url , model , input_func , input_names , v .get ("outputs" ), ** kwargs )
341
365
tests [k ] = test
342
366
return tests
343
367
0 commit comments