Skip to content

Commit a66c448

Browse files
committed
run_pretrained_model path relative to config dir
1 parent 21b3e5f commit a66c448

File tree

2 files changed

+44
-20
lines changed

2 files changed

+44
-20
lines changed

tests/run_pretrained_models.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import argparse
1212
import os
13+
import re
1314
import sys
1415
import tarfile
1516
import time
@@ -98,8 +99,8 @@ def __init__(self, url, local, make_input, input_names, output_names,
9899
self.force_input_shape = force_input_shape
99100
self.skip_tensorflow = skip_tensorflow
100101

101-
def download_file(self):
102-
"""Download file from url."""
102+
def download_model(self):
103+
"""Download model from url."""
103104
cache_dir = Test.cache_dir
104105
if not os.path.exists(cache_dir):
105106
os.makedirs(cache_dir)
@@ -193,13 +194,13 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
193194

194195
# get the model
195196
if self.url:
196-
_, dir_name = self.download_file()
197+
_, dir_name = self.download_model()
198+
logger.info("Downloaded to %s", dir_name)
197199
model_path = os.path.join(dir_name, self.local)
198200
else:
199201
model_path = self.local
200-
dir_name = os.path.dirname(self.local)
201-
logger.info("Downloaded to %s", model_path)
202202

203+
logger.info("Load model from %s", model_path)
203204
input_names = list(self.input_names.keys())
204205
outputs = self.output_names
205206
if self.model_type in ["checkpoint"]:
@@ -324,10 +325,13 @@ def get_args():
324325
return args
325326

326327

327-
def tests_from_yaml(fname):
328+
def tests_from_yaml(path):
328329
"""Create test class from yaml file."""
330+
path = os.path.abspath(path)
331+
base_dir = os.path.dirname(path)
332+
329333
tests = {}
330-
config = yaml.load(open(fname, 'r').read())
334+
config = yaml.load(open(path, 'r').read())
331335
for k, v in config.items():
332336
input_func = v.get("input_get")
333337
input_func = _INPUT_FUNC_MAPPING[input_func]
@@ -337,7 +341,27 @@ def tests_from_yaml(fname):
337341
if v.get(kw) is not None:
338342
kwargs[kw] = v[kw]
339343

340-
test = Test(v.get("url"), v.get("model"), input_func, v.get("inputs"), v.get("outputs"), **kwargs)
344+
# when model is local, non-absolute path is relative to yaml directory
345+
url = v.get("url")
346+
model = v.get("model")
347+
if not url:
348+
if not os.path.isabs(model):
349+
model = os.path.join(base_dir, model)
350+
351+
# non-absolute npy file path for np.load is relative to yaml directory
352+
input_names = v.get("inputs")
353+
for key in list(input_names.keys()):
354+
value = input_names[key]
355+
if isinstance(value, str):
356+
# assume at most 1 match
357+
matches = re.findall(r"np\.load\((r?['\"].*?['\"])", value)
358+
if matches:
359+
npy_path = matches[0].lstrip('r').strip("'").strip('"')
360+
if not os.path.isabs(npy_path):
361+
abs_npy_path = os.path.join(base_dir, npy_path)
362+
input_names[key] = value.replace(matches[0], "r'{}'".format(abs_npy_path))
363+
364+
test = Test(url, model, input_func, input_names, v.get("outputs"), **kwargs)
341365
tests[k] = test
342366
return tests
343367

tests/run_pretrained_models.yaml

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
# simple models for basic functional test
33
#
44
regression-graphdef:
5-
model: tests/models/regression/graphdef/frozen.pb
5+
model: models/regression/graphdef/frozen.pb
66
input_get: get_ramp
77
inputs:
88
"X:0": [1]
99
outputs:
1010
- pred:0
1111

1212
regression-checkpoint:
13-
model: tests/models/regression/checkpoint/model.meta
13+
model: models/regression/checkpoint/model.meta
1414
model_type: checkpoint
1515
input_get: get_ramp
1616
inputs:
@@ -19,7 +19,7 @@ regression-checkpoint:
1919
- pred:0
2020

2121
regression-saved-model:
22-
model: tests/models/regression/saved_model
22+
model: models/regression/saved_model
2323
model_type: saved_model
2424
input_get: get_ramp
2525
inputs:
@@ -28,7 +28,7 @@ regression-saved-model:
2828
- pred:0
2929

3030
saved_model_with_redundant_inputs:
31-
model: tests/models/saved_model_with_redundant_inputs
31+
model: models/saved_model_with_redundant_inputs
3232
model_type: saved_model
3333
input_get: get_ramp
3434
inputs:
@@ -38,7 +38,7 @@ saved_model_with_redundant_inputs:
3838
- Add:0
3939

4040
graphdef_with_redundant_inputs:
41-
model: tests/models/regression/graphdef/frozen.pb
41+
model: models/regression/graphdef/frozen.pb
4242
input_get: get_ramp
4343
inputs:
4444
"X:0": [1, 10]
@@ -47,7 +47,7 @@ graphdef_with_redundant_inputs:
4747
- Add:0
4848

4949
checkpoint_with_redundant_inputs:
50-
model: tests/models/regression/checkpoint/model.meta
50+
model: models/regression/checkpoint/model.meta
5151
model_type: checkpoint
5252
input_get: get_ramp
5353
inputs:
@@ -57,15 +57,15 @@ checkpoint_with_redundant_inputs:
5757
- pred:0
5858

5959
benchtf-fc:
60-
model: tests/models/fc-layers/frozen.pb
60+
model: models/fc-layers/frozen.pb
6161
input_get: get_ramp
6262
inputs:
6363
"X:0": [1, 784]
6464
outputs:
6565
- output:0
6666

6767
benchtf-conv:
68-
model: tests/models/conv-layers/frozen.pb
68+
model: models/conv-layers/frozen.pb
6969
input_get: get_ramp
7070
inputs:
7171
"X:0": [1, 784]
@@ -74,15 +74,15 @@ benchtf-conv:
7474

7575
benchtf-convbn:
7676
disabled: true # some if from training isn't removed
77-
model: tests/models/convbn-layers/frozen.pb
77+
model: models/convbn-layers/frozen.pb
7878
input_get: get_ramp
7979
inputs:
8080
"X:0": [1, 784]
8181
outputs:
8282
- output:0
8383

8484
benchtf-ae0:
85-
model: tests/models/ae0/frozen.pb
85+
model: models/ae0/frozen.pb
8686
input_get: get_ramp
8787
inputs:
8888
"X:0": [1, 784]
@@ -91,7 +91,7 @@ benchtf-ae0:
9191

9292
benchtf-lstm:
9393
disabled: true
94-
model: tests/models/lstm/frozen.pb
94+
model: models/lstm/frozen.pb
9595
input_get: get_ramp
9696
inputs:
9797
"X:0": [1, 784]
@@ -100,7 +100,7 @@ benchtf-lstm:
100100

101101
benchtf-gru:
102102
disabled: true
103-
model: tests/models/gru/frozen.pb
103+
model: models/gru/frozen.pb
104104
input_get: get_ramp
105105
inputs:
106106
"X:0": [1, 784]

0 commit comments

Comments
 (0)