Skip to content

Commit 94d8113

Browse files
authored
add more huggingface models for testing (#1469)
* add tf-2.5 to ci Signed-off-by: Guenther Schmuelling <[email protected]> * theta must be >= 0 Signed-off-by: Guenther Schmuelling <[email protected]> * add more huggingface models for testing Signed-off-by: Guenther Schmuelling <[email protected]> * pylint Signed-off-by: Guenther Schmuelling <[email protected]> * pylint Signed-off-by: Guenther Schmuelling <[email protected]> * list python version in setup Signed-off-by: Guenther Schmuelling <[email protected]>
1 parent e5b1b2a commit 94d8113

File tree

4 files changed

+123
-62
lines changed

4 files changed

+123
-62
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ The common issues we run into we try to document here [Troubleshooting Guide](Tr
1818

1919
| Build Type | OS | Python | Tensorflow | Onnx opset | Status |
2020
| --- | --- | --- | --- | --- | --- |
21-
| Unit Test - Basic | Linux, MacOS<sup>\*</sup>, Windows<sup>\*</sup> | 3.6, 3.7, 3.8 | 1.12-1.15, 2.1-2.4 | 7-13 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=16&branchName=master) |
22-
| Unit Test - Full | Linux, MacOS, Windows | 3.6, 3.7, 3.8 | 1.12-1.15, 2.1-2.4 | 7-13 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test-matrix?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=18&branchName=master) | |
21+
| Unit Test - Basic | Linux, MacOS<sup>\*</sup>, Windows<sup>\*</sup> | 3.6, 3.7, 3.8 | 1.12-1.15, 2.1-2.5 | 7-13 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=16&branchName=master) |
22+
| Unit Test - Full | Linux, MacOS, Windows | 3.6, 3.7, 3.8 | 1.12-1.15, 2.1-2.5 | 7-13 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test-matrix?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=18&branchName=master) |
2323
<br/>
2424

2525
## Supported Versions

setup.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,5 +81,22 @@ def run(self):
8181
8282
author_email='[email protected]',
8383
url='https://github.com/onnx/tensorflow-onnx',
84-
install_requires=['numpy>=1.14.1', 'onnx>=1.4.1', 'requests', 'six', 'flatbuffers']
84+
install_requires=['numpy>=1.14.1', 'onnx>=1.4.1', 'requests', 'six', 'flatbuffers'],
85+
classifiers=[
86+
'Development Status :: 5 - Production/Stable',
87+
'Intended Audience :: Developers',
88+
'Intended Audience :: Education',
89+
'Intended Audience :: Science/Research',
90+
'License :: OSI Approved :: Apache2 License',
91+
'Topic :: Scientific/Engineering',
92+
'Topic :: Scientific/Engineering :: Mathematics',
93+
'Topic :: Scientific/Engineering :: Artificial Intelligence',
94+
'Topic :: Software Development',
95+
'Topic :: Software Development :: Libraries',
96+
'Topic :: Software Development :: Libraries :: Python Modules',
97+
'Programming Language :: Python :: 3',
98+
'Programming Language :: Python :: 3.6',
99+
'Programming Language :: Python :: 3.7',
100+
'Programming Language :: Python :: 3.8',
101+
'Programming Language :: Python :: 3.9']
85102
)

tests/huggingface.py

Lines changed: 64 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3-
"""Unit tests for huggingface tensorflow transformers."""
3+
"""
4+
Unit tests for huggingface tensorflow transformers.
5+
6+
tested with tf-2.4.1, transformers-4.5.1
7+
8+
"""
49

510
# pylint: disable=missing-docstring,invalid-name,unused-argument
611
# pylint: disable=bad-classmethod-argument,wrong-import-position
@@ -19,7 +24,9 @@
1924
import tensorflow as tf
2025
import tf2onnx
2126

22-
compare_perf = False
27+
compare_perf = True
28+
time_to_run = 10
29+
time_step = 10
2330

2431

2532
class TestTransformers(unittest.TestCase):
@@ -47,26 +54,31 @@ def run_onnxruntime(self, model_path, input_dict, output_names):
4754
m = rt.InferenceSession(model_path, sess_options=opt, providers=providers)
4855
results = m.run(output_names, input_dict)
4956
if compare_perf:
50-
count = 10
57+
n = 0
5158
time_start = time.time()
52-
for _ in range(count):
53-
_ = m.run(output_names, input_dict.keys())
59+
time_stop = time_start + time_to_run
60+
while time.time() < time_stop:
61+
for _ in range(time_step):
62+
_ = m.run(output_names, input_dict)
63+
n += time_step
5464
time_end = time.time()
55-
val = str((time_end - time_start) / count)
56-
print(f'==== avg ort name={self.name}, time={val}')
65+
val = (time_end - time_start) / n
66+
print(f'= avg ort name={self.name}, time={val}, n={n}')
5767
return results
5868

5969
def run_keras(self, model, inputs):
60-
print(f"==== {self.name}")
6170
pred = model(inputs)
6271
if compare_perf:
63-
count = 10
72+
n = 0
6473
time_start = time.time()
65-
for _ in range(count):
66-
_ = model(inputs)
67-
time_end = time.time()
68-
val = str((time_end - time_start) / count)
69-
print(f'==== avg keras name={self.name}, time={val}')
74+
time_stop = time_start + time_to_run
75+
while time.time() < time_stop:
76+
for _ in range(time_step):
77+
_ = model(inputs)
78+
n += time_step
79+
time_stop = time.time()
80+
val = (time_stop - time_start) / n
81+
print(f'= avg keras name={self.name}, time={val}, n={n}')
7082
return pred
7183

7284
def run_test(self, model, input_dict, rtol=1e-2, atol=1e-4, input_signature=None,
@@ -96,8 +108,11 @@ def run_test(self, model, input_dict, rtol=1e-2, atol=1e-4, input_signature=None
96108
if not large:
97109
model_path = model_path + ".onnx"
98110
print("= convert")
111+
time_start = time.time()
99112
_, _ = tf2onnx.convert.from_keras(model, input_signature=input_signature,
100113
opset=13, large_model=large, output_path=model_path)
114+
time_stop = time.time()
115+
print(f"= convertsion took {time_stop - time_start}")
101116

102117
if large:
103118
# need to unpack the zip for run_onnxruntime()
@@ -163,18 +178,45 @@ def test_TFDisillBertModel(self):
163178

164179
## FUNNEL
165180

166-
def _test_TFFunnelSquad(self, size, large=False):
181+
def _test_TFFunnel(self, size, large=False):
167182
from transformers import FunnelTokenizer, TFFunnelForQuestionAnswering
168183
tokenizer = FunnelTokenizer.from_pretrained(size)
169184
model = TFFunnelForQuestionAnswering.from_pretrained(size)
170185
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
171186
input_dict = tokenizer(question, text, return_tensors='tf')
172-
spec, input_dict = self.spec_and_pad(input_dict, max_length=model.config.max_length)
187+
spec, input_dict = self.spec_and_pad(input_dict, 128)
173188
outputs = ["start_logits", "end_logits"]
174189
self.run_test(model, input_dict, input_signature=spec, outputs=outputs, rtol=1e-5)
175190

176-
def test_TFFunnelSquadSmall(self):
177-
self._test_TFFunnelSquad("funnel-transformer/small")
191+
def test_TFFunnelSmall(self):
192+
self._test_TFFunnel("funnel-transformer/small")
193+
194+
def test_TFFunnelSmallBase(self):
195+
self._test_TFFunnel("funnel-transformer/small-base")
196+
197+
def test_TFFunnelMedium(self):
198+
self._test_TFFunnel("funnel-transformer/medium")
199+
200+
def test_TFFunnelMediumBase(self):
201+
self._test_TFFunnel("funnel-transformer/medium-base")
202+
203+
def test_TFFunnelIntermediate(self):
204+
self._test_TFFunnel("funnel-transformer/intermediate")
205+
206+
def test_TFFunnelIntermediateBase(self):
207+
self._test_TFFunnel("funnel-transformer/intermediate-base")
208+
209+
def test_TFFunnelLarge(self):
210+
self._test_TFFunnel("funnel-transformer/large")
211+
212+
def test_TFFunnelLargeBase(self):
213+
self._test_TFFunnel("funnel-transformer/large-base")
214+
215+
def test_TFFunnelXLarge(self):
216+
self._test_TFFunnel("funnel-transformer/xlarge")
217+
218+
def test_TFFunnelXLargeBase(self):
219+
self._test_TFFunnel("funnel-transformer/xlarge-base")
178220

179221
## T5
180222

@@ -352,13 +394,16 @@ def _test_TFBart(self, size, large=False):
352394
tokenizer = BartTokenizer.from_pretrained(size)
353395
model = TFBartModel.from_pretrained(size)
354396
input_dict = tokenizer("Hello, my dog is cute", return_tensors="tf")
355-
spec, input_dict = self.spec_and_pad(input_dict, max_length=model.config.max_length)
397+
spec, input_dict = self.spec_and_pad(input_dict, max_length=128)
356398
outputs = ["last_hidden_state"]
357399
self.run_test(model, input_dict, input_signature=spec, outputs=outputs, large=large)
358400

359401
def test_TFBartBase(self):
360402
self._test_TFBart("facebook/bart-base", large=True)
361403

404+
def test_TFBartLarge(self):
405+
self._test_TFBart("facebook/bart-large", large=True)
406+
362407
def test_TFBartLargeCnn(self):
363408
self._test_TFBart("facebook/bart-large-cnn", large=True)
364409

tests/run_pretrained_models.py

Lines changed: 39 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@
5353
logger = logging.getLogger("run_pretrained")
5454

5555
TEMP_DIR = os.path.join(utils.get_temp_directory(), "run_pretrained")
56-
PERFITER = 1000
57-
56+
PERF_STEP = 10
57+
PERF_TIME = 10
5858

5959
def get_img(shape, path, dtype, should_scale=True):
6060
"""Get image as input."""
@@ -292,10 +292,15 @@ def run_tensorflow(self, sess, inputs):
292292
result = sess.run(self.output_names, feed_dict=feed_dict)
293293
if self.perf:
294294
logger.info("Running TF perf")
295+
n = 0
295296
start = time.time()
296-
for _ in range(PERFITER):
297-
_ = sess.run(self.output_names, feed_dict=feed_dict)
298-
self.tf_runtime = time.time() - start
297+
stop = start + PERF_TIME
298+
while time.time() < stop:
299+
for _ in range(PERF_STEP):
300+
_ = sess.run(self.output_names, feed_dict=feed_dict)
301+
n += PERF_STEP
302+
self.tf_runtime = 1000 * (time.time() - start) / n
303+
logger.info("TF perf {:.2f}ms/inference, n={}".format(self.tf_runtime, n))
299304
return result
300305

301306
def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, input_names=None,
@@ -312,18 +317,6 @@ def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, i
312317
tflite_path=tflite_path, dequantize=self.dequantize,
313318
tensors_to_rename=tensors_to_rename)
314319

315-
def run_caffe2(self, name, model_proto, inputs):
316-
"""Run test again caffe2 backend."""
317-
import caffe2.python.onnx.backend
318-
prepared_backend = caffe2.python.onnx.backend.prepare(model_proto)
319-
results = prepared_backend.run(inputs)
320-
if self.perf:
321-
start = time.time()
322-
for _ in range(PERFITER):
323-
_ = prepared_backend.run(inputs)
324-
self.onnx_runtime = time.time() - start
325-
return results
326-
327320
def run_onnxruntime(self, name, model_proto, inputs, outputs, external_tensor_storage=None):
328321
"""Run test against onnxruntime backend."""
329322
import onnxruntime as rt
@@ -340,10 +333,15 @@ def run_onnxruntime(self, name, model_proto, inputs, outputs, external_tensor_st
340333
m = rt.InferenceSession(model_path)
341334
results = m.run(outputs, inputs)
342335
if self.perf:
336+
n = 0
343337
start = time.time()
344-
for _ in range(PERFITER):
345-
_ = m.run(outputs, inputs)
346-
self.onnx_runtime = time.time() - start
338+
stop = start + PERF_TIME
339+
while time.time() < stop:
340+
for _ in range(PERF_STEP):
341+
_ = m.run(outputs, inputs)
342+
n += PERF_STEP
343+
self.onnx_runtime = 1000 * (time.time() - start) / n
344+
logger.info("ORT perf {:.2f}ms/inference, n={}".format(self.onnx_runtime, n))
347345
return results
348346

349347
@staticmethod
@@ -357,8 +355,7 @@ def create_onnx_file(name, model_proto, inputs, outdir, external_tensor_storage=
357355
utils.save_onnx_zip(model_path, model_proto, external_tensor_storage)
358356
logger.info("Created %s", model_path)
359357

360-
def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_opset=None,
361-
perf=None, fold_const=None):
358+
def run_test(self, name, backend="onnxruntime", onnx_file=None, opset=None, extra_opset=None, perf=None):
362359
"""Run complete test against backend."""
363360
self.perf = perf
364361

@@ -422,10 +419,15 @@ def run_tflite():
422419
tf_results = run_tflite()
423420
if self.perf:
424421
logger.info("Running TFLite perf")
422+
n = 0
425423
start = time.time()
426-
for _ in range(PERFITER):
427-
_ = run_tflite()
428-
self.tf_runtime = time.time() - start
424+
stop = start + PERF_TIME
425+
while time.time() < stop:
426+
for _ in range(PERF_STEP):
427+
_ = run_tflite()
428+
n += PERF_STEP
429+
self.tf_runtime = 1000 * (time.time() - start) / n
430+
logger.info("TFLite perf {:.2f}ms/inference, n={}".format(self.tf_runtime, n))
429431
logger.info("TFLite OK")
430432

431433
if not self.run_tf_frozen:
@@ -444,10 +446,15 @@ def run_tflite():
444446
tf_results = [tf_res.numpy() for tf_res in tf_results]
445447
if self.perf:
446448
logger.info("Running TF perf")
449+
n = 0
447450
start = time.time()
448-
for _ in range(PERFITER):
449-
_ = concrete_func(**inputs)
450-
self.tf_runtime = time.time() - start
451+
stop = start + PERF_TIME
452+
while time.time() < stop:
453+
for _ in range(PERF_STEP):
454+
_ = concrete_func(**inputs)
455+
n += PERF_STEP
456+
self.tf_runtime = 1000 * (time.time() - start) / n
457+
logger.info("TF perf {:.2f}ms/inference, n={}".format(self.tf_runtime, n))
451458
logger.info("TensorFlow OK")
452459

453460
shape_override = {}
@@ -533,9 +540,7 @@ def run_tflite():
533540

534541
try:
535542
onnx_results = None
536-
if backend == "caffe2":
537-
onnx_results = self.run_caffe2(name, model_proto, inputs)
538-
elif backend == "onnxruntime":
543+
if backend == "onnxruntime":
539544
if to_rename is None:
540545
struc_outputs = self.output_names
541546
else:
@@ -614,7 +619,7 @@ def get_args():
614619
parser.add_argument("--tests", help="tests to run")
615620
parser.add_argument("--target", default="", help="target platform")
616621
parser.add_argument("--backend", default="onnxruntime",
617-
choices=["caffe2", "onnxruntime"], help="backend to use")
622+
choices=["onnxruntime"], help="backend to use")
618623
parser.add_argument("--opset", type=int, default=None, help="opset to use")
619624
parser.add_argument("--extra_opset", default=None,
620625
help="extra opset with format like domain:version, e.g. com.microsoft:1")
@@ -625,9 +630,6 @@ def get_args():
625630
parser.add_argument("--list", help="list tests", action="store_true")
626631
parser.add_argument("--onnx-file", help="create onnx file in directory")
627632
parser.add_argument("--perf", help="capture performance numbers")
628-
parser.add_argument("--perfiter", type=int, default=PERFITER, help="number of inferences for perf testing")
629-
parser.add_argument("--fold_const", help="enable tf constant_folding transformation before conversion",
630-
action="store_true")
631633
parser.add_argument("--include-disabled", help="include disabled tests", action="store_true")
632634
args = parser.parse_args()
633635

@@ -699,7 +701,6 @@ def load_tests_from_yaml(path):
699701

700702

701703
def main():
702-
global PERFITER
703704
args = get_args()
704705
logging.basicConfig(level=logging.get_verbosity_level(args.verbose))
705706
if args.debug:
@@ -718,7 +719,6 @@ def main():
718719

719720
failed = 0
720721
count = 0
721-
PERFITER = args.perfiter
722722
for test in test_keys:
723723
logger.info("===================================")
724724

@@ -749,8 +749,7 @@ def main():
749749
try:
750750
logger.info("Running %s", test)
751751
ret = t.run_test(test, backend=args.backend, onnx_file=args.onnx_file,
752-
opset=args.opset, extra_opset=args.extra_opset, perf=args.perf,
753-
fold_const=args.fold_const)
752+
opset=args.opset, extra_opset=args.extra_opset, perf=args.perf)
754753
except Exception:
755754
logger.error("Failed to run %s", test, exc_info=1)
756755
ret = None
@@ -770,7 +769,7 @@ def main():
770769
t = tests[test]
771770
if t.perf:
772771
# Report perf in ms per inference
773-
f.write("{},{},{}\n".format(test, t.tf_runtime * 1000 / PERFITER, t.onnx_runtime * 1000 / PERFITER))
772+
f.write("{},{},{}\n".format(test, t.tf_runtime, t.onnx_runtime))
774773
return failed
775774

776775

0 commit comments

Comments
 (0)