Skip to content

Commit 9743c05

Browse files
committed
Update run_pretrained_models.py
1 parent a521221 commit 9743c05

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

tests/run_pretrained_models.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@
5151
def get_beach(shape):
5252
"""Get beach image as input."""
5353
resize_to = shape[1:3]
54-
path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
55-
'..', 'tests', "beach.jpg")
54+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "beach.jpg")
5655
img = PIL.Image.open(path)
5756
img = img.resize(resize_to, PIL.Image.ANTIALIAS)
5857
img_np = np.array(img).astype(np.float32)
@@ -247,7 +246,21 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
247246
if self.model_type in ["checkpoint"]:
248247
graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
249248
elif self.model_type in ["saved_model"]:
250-
graph_def, input_names, outputs = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag)
249+
try:
250+
res = tf_loader.from_saved_model(
251+
model_path, input_names, outputs, self.tag, self.signatures, self.concrete_function, self.large_model)
252+
except OSError:
253+
model_path = dir_name
254+
logger.info("Load model(2) from %r", model_path)
255+
res = tf_loader.from_saved_model(
256+
model_path, input_names, outputs, self.tag, self.signatures, self.concrete_function, self.large_model)
257+
if len(res) == 5:
258+
graph_def, input_names, outputs, concrete_func, imported = res
259+
elif len(res) == 3:
260+
graph_def, input_names, outputs = res
261+
concrete_func, imported = None, None
262+
else:
263+
raise OSError("Unexpected number of results %r." % len(res))
251264
elif self.model_type in ["keras"]:
252265
graph_def, input_names, outputs = tf_loader.from_keras(model_path, input_names, outputs)
253266
else:

0 commit comments

Comments
 (0)