Skip to content

Commit 93fe844

Browse files
xadupresdpython
andauthored
fix scripts (#1626)
Signed-off-by: xavier dupré <[email protected]> Co-authored-by: xavier dupré <[email protected]>
1 parent 5892425 commit 93fe844

File tree

5 files changed

+42
-4
lines changed

5 files changed

+42
-4
lines changed

tests/tfhub/_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
301301
"Unable to check discrepencies for res=%r." % res) from e
302302
except AssertionError as e:
303303
output_names = [o.name for o in ort.get_outputs()]
304-
res = ort.run(None, imgs[0])
304+
res = fct_ort(imgs[0])
305305
for i, r in enumerate(res):
306306
print("ORT %d: %s: %r: %r" % (i, output_names[i], r.dtype, r.shape))
307307
raise e

tests/tfhub/tfhub_resnet_v1_101.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def main(opset=13):
1010
name = "resnet_v1_101"
1111
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
1212

13-
imgs = generate_random_images(shape=(1, 224, 224, 3))
13+
imgs = generate_random_images(shape=(1, 224, 224, 3), scale=1.)
1414

1515
benchmark(url, dest, onnx_name, opset, imgs)
1616

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import os
3+
import numpy
4+
import onnxruntime as ort
5+
import tensorflow as tf
6+
import tensorflow_hub as hub
7+
import tf2onnx
8+
from _tools import generate_random_images, check_discrepencies
9+
10+
imgs = generate_random_images(shape=(1, 224, 224, 3), scale=1.)
11+
12+
model = tf.keras.Sequential([
13+
hub.KerasLayer("https://tfhub.dev/google/imagenet/resnet_v1_101/feature_vector/5",
14+
trainable=False)])
15+
model.build([None, 224, 224, 3])
16+
17+
expected_output = model(imgs[0])
18+
19+
dest = "tf-resnet_v1_101"
20+
if not os.path.exists(dest):
21+
os.makedirs(dest)
22+
dest_name = os.path.join(dest, "resnet_v1_101-13-keras.onnx")
23+
if not os.path.exists(dest_name):
24+
tf2onnx.convert.from_keras(model, opset=13, output_path=dest_name)
25+
26+
sess = ort.InferenceSession(dest_name)
27+
print('inputs', [_.name for _ in sess.get_inputs()])
28+
ort_output = sess.run(None, {"keras_layer_input": imgs[0]})
29+
30+
print("Actual")
31+
print(ort_output)
32+
print("Expected")
33+
print(expected_output)
34+
35+
diff = expected_output.numpy() - ort_output[0]
36+
max_diff = numpy.abs(diff).max()
37+
rel_diff = (numpy.abs(diff) / (expected_output.numpy() + 1e-5)).max()
38+
print(max_diff, rel_diff, [ort_output[0].min(), ort_output[0].max()])

tests/tfhub/tfhub_resnet_v2_101.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def main(opset=13):
1010
name = "resnet_v2_101"
1111
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
1212

13-
imgs = generate_random_images(shape=(1, 224, 224, 3))
13+
imgs = generate_random_images(shape=(1, 224, 224, 3), scale=1.)
1414

1515
benchmark(url, dest, onnx_name, opset, imgs)
1616

tests/tfhub/tfhub_resnet_v2_101_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def main(opset=13):
1010
name = "resnet_v2_101_classification"
1111
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
1212

13-
imgs = generate_random_images(shape=(1, 224, 224, 3))
13+
imgs = generate_random_images(shape=(1, 224, 224, 3), scale=1.)
1414

1515
benchmark(url, dest, onnx_name, opset, imgs)
1616

0 commit comments

Comments
 (0)