Skip to content

Commit 010fad1

Browse files
authored
Merge pull request #1098 from xadupre/esrgan
Fix issue with model esrgan-tf2_1
2 parents f5f8b2d + bcdedbe commit 010fad1

File tree

5 files changed

+86
-3
lines changed

5 files changed

+86
-3
lines changed

examples/benchmark_tfmodel_ort.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
The following code compares the speed of tensorflow against onnxruntime
3+
with a model downloaded from Tensorflow Hub.
4+
"""
5+
import time
6+
import numpy
7+
from tqdm import tqdm
8+
import tensorflow_hub as hub
9+
import onnxruntime as ort
10+
11+
12+
def generate_random_images(shape=(100, 100), n=10):
13+
imgs = []
14+
for i in range(n):
15+
sh = (1,) + shape + (3,)
16+
img = numpy.clip(numpy.abs(numpy.random.randn(*sh)), 0, 1) * 255
17+
img = img.astype(numpy.float32)
18+
imgs.append(img)
19+
return imgs
20+
21+
22+
def measure_time(fct, imgs):
23+
results = []
24+
times = []
25+
for img in tqdm(imgs):
26+
begin = time.perf_counter()
27+
result = fct(img)
28+
end = time.perf_counter()
29+
results.append(result)
30+
times.append(end - begin)
31+
return results, times
32+
33+
34+
imgs = generate_random_images()
35+
36+
# Download model from https://tfhub.dev/captain-pool/esrgan-tf2/1
37+
# python -m tf2onnx.convert --saved-model esrgan --output "esrgan-tf2.onnx" --opset 12
38+
ort = ort.InferenceSession('esrgan-tf2.onnx')
39+
fct_ort = lambda img: ort.run(None, {'input_0:0': img})
40+
results_ort, duration_ort = measure_time(fct_ort, imgs)
41+
print(len(imgs), duration_ort)
42+
43+
model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
44+
results_tf, duration_tf = measure_time(model, imgs)
45+
print(len(imgs), duration_tf)
46+
47+
print("ratio ORT / TF", sum(duration_ort) / sum(duration_tf))

tests/run_pretrained_models.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,13 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
246246
if self.model_type in ["checkpoint"]:
247247
graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
248248
elif self.model_type in ["saved_model"]:
249-
graph_def, input_names, outputs = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag)
249+
try:
250+
res = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag)
251+
except OSError:
252+
model_path = dir_name
253+
logger.info("Load model(2) from %r", model_path)
254+
res = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag)
255+
graph_def, input_names, outputs = res[:3]
250256
elif self.model_type in ["keras"]:
251257
graph_def, input_names, outputs = tf_loader.from_keras(model_path, input_names, outputs)
252258
else:

tests/run_pretrained_models.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,24 @@ benchtf-gru:
115115
##
116116
## standard image nets
117117
##
118+
119+
esrgan-tf2:
120+
# url: https://tfhub.dev/captain-pool/esrgan-tf2/1/esrgan-tf2_1.tar.gz
121+
url: https://github.com/captain-pool/GSOC/releases/download/1.0.0/esrgan.tar.gz
122+
model: ersgan
123+
model_type: saved_model
124+
input_get: get_beach
125+
opset_constraints:
126+
"onnx":
127+
"min": 10
128+
inputs:
129+
"input_0:0": [1, 50, 50, 3]
130+
outputs:
131+
- Identity:0
132+
rtol: 0.02
133+
atol: 0.0005
134+
tf_min_version: 2.1
135+
118136
inception_v3_slim:
119137
url: https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz
120138
model: inception_v3_2016_08_28_frozen.pb

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,10 @@ def _reducemean_handler(self, trans, node):
648648
def _slice_handler(self, trans, node):
649649
axes = None
650650
if self._g.opset < 10:
651-
axes = node.get_attr("axes").ints
651+
axes_values = node.get_attr("axes")
652+
if not axes_values:
653+
return False
654+
axes = axes_values.ints
652655
if axes == [0, 1, 2, 3]:
653656
node.set_attr("axes", NCHW_TO_NHWC)
654657
return self._switch_transpose_and_node(node, trans)

tf2onnx/tf_loader.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,16 @@ def from_graphdef(model_path, input_names, output_names):
168168
with tf_session() as sess:
169169
graph_def = tf_graphdef()
170170
with tf_gfile.GFile(model_path, 'rb') as f:
171-
graph_def.ParseFromString(f.read())
171+
try:
172+
content = f.read()
173+
except Exception as e:
174+
raise OSError(
175+
"Unable to load file '{}'.".format(model_path)) from e
176+
try:
177+
graph_def.ParseFromString(content)
178+
except Exception as e:
179+
raise RuntimeError(
180+
"Unable to parse file '{}'.".format(model_path)) from e
172181
tf.import_graph_def(graph_def, name='')
173182
input_names = inputs_without_resource(sess, input_names)
174183
frozen_graph = freeze_session(sess, input_names=input_names, output_names=output_names)

0 commit comments

Comments
 (0)