Skip to content

Commit 0b9da56

Browse files
committed
refactor
Signed-off-by: xavier dupré <[email protected]>
1 parent db230dc commit 0b9da56

File tree

6 files changed

+30
-5
lines changed

6 files changed

+30
-5
lines changed

tests/tfhub/tfhub_albert_en_xlarge.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import os
3+
import numpy
4+
from _tools import generate_random_images, benchmark
5+
6+
7+
def main(opset=13):
8+
url = "https://tfhub.dev/tensorflow/albert_en_xlarge/3?tf-hub-format=compressed"
9+
dest = "tf-albert-en-xlarge"
10+
name = "albert-en-xlarge"
11+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
12+
13+
imgs = generate_random_images(shape=(1, 256, 256, 3), dtype=numpy.int32)
14+
15+
benchmark(url, dest, onnx_name, opset, imgs,
16+
signature='serving_default')
17+
18+
19+
if __name__ == "__main__":
20+
main()

tests/tfhub/tfhub_esrgan.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ def main(opset=13):
1515
benchmark(url, dest, onnx_name, opset, imgs)
1616

1717

18-
main()
18+
if __name__ == "__main__":
19+
main()

tests/tfhub/tfhub_resnet_v2_152.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ def main(opset=13):
1515
benchmark(url, dest, onnx_name, opset, imgs)
1616

1717

18-
main()
18+
if __name__ == "__main__":
19+
main()

tests/tfhub/tfhub_spam_detection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@ def main(opset=13):
1616
benchmark(url, dest, onnx_name, opset, imgs)
1717

1818

19-
main()
19+
if __name__ == "__main__":
20+
main()

tests/tfhub/tfhub_thunder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@ def main(opset=13):
1616
signature='serving_default')
1717

1818

19-
main()
19+
if __name__ == "__main__":
20+
main()

tests/tfhub/tfhub_yamnet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ def main(opset=13):
1515
benchmark_tflite(url, dest, onnx_name, opset, imgs)
1616

1717

18-
main()
18+
if __name__ == "__main__":
19+
main()

0 commit comments

Comments
 (0)