Skip to content

Commit 33e43b4

Browse files
xadupresdpython
andauthored
[WIP] Fix tfhub scripts (enformer, ...) (#1637)
* Scripts and errors Signed-off-by: xavier dupré <[email protected]> * yamnet Signed-off-by: xavier dupré <[email protected]> Co-authored-by: xavier dupré <[email protected]>
1 parent f221917 commit 33e43b4

File tree

7 files changed

+217
-27
lines changed

7 files changed

+217
-27
lines changed

tests/tfhub/_tools.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -128,27 +128,29 @@ def download_tflite(url, dest, verbose=True):
128128
return fpath
129129

130130

131-
def convert_model(model_name, output_path, opset=13, tag=None, verbose=True):
131+
def convert_model(model_name, output_path, opset=13, tag=None, signature=None, verbose=True):
132132
"""
133133
Converts the downloaded model into ONNX.
134134
"""
135135
ext = os.path.splitext(output_path)[-1]
136136
large_model = ext == ".zip"
137137
if not os.path.exists(output_path):
138138
begin = datetime.datetime.now()
139-
cmdl = ['-m', 'tf2onnx.convert', '--saved-model',
140-
'"%s"' % os.path.abspath(model_name).replace("\\", "/"),
141-
'--output', '"%s"' % os.path.abspath(output_path).replace("\\", "/"),
139+
cmdl = ['python', '-m', 'tf2onnx.convert', '--saved-model',
140+
'%s' % os.path.abspath(model_name).replace("\\", "/"),
141+
'--output', '%s' % os.path.abspath(output_path).replace("\\", "/"),
142142
'--opset', "%d" % opset]
143+
if signature is not None:
144+
cmdl.append('--signature_def=%s' % signature)
143145
if tag is not None:
144-
cmdl.append('--tag="%s"' % tag)
146+
cmdl.append('--tag=%s' % tag)
145147
if large_model:
146148
cmdl.append('--large_model')
147149
if verbose:
148-
print("cmd: python %s" % " ".join(cmdl))
150+
print("cmd: %s" % " ".join(cmdl))
149151
pproc = subprocess.Popen(
150-
cmdl, shell=True, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
151-
executable=sys.executable.replace("pythonw", "python"))
152+
cmdl, shell=False, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
153+
executable=None)
152154
stdoutdata, stderrdata = pproc.communicate()
153155
if verbose:
154156
print('--OUT--')
@@ -164,15 +166,15 @@ def convert_tflite(model_name, output_path, opset=13, verbose=True):
164166
"""
165167
if not os.path.exists(output_path):
166168
begin = datetime.datetime.now()
167-
cmdl = ['-m', 'tf2onnx.convert', '--tflite',
169+
cmdl = ['python', '-m', 'tf2onnx.convert', '--tflite',
168170
'"%s"' % os.path.abspath(model_name).replace("\\", "/"),
169171
'--output', '"%s"' % os.path.abspath(output_path).replace("\\", "/"),
170172
'--opset', "%d" % opset]
171173
if verbose:
172-
print("cmd: python %s" % " ".join(cmdl))
174+
print("cmd: %s" % " ".join(cmdl))
173175
pproc = subprocess.Popen(
174-
cmdl, shell=True, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
175-
executable=sys.executable.replace("pythonw", "python"))
176+
cmdl, shell=False, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
177+
executable=None)
176178
stdoutdata, stderrdata = pproc.communicate()
177179
if verbose:
178180
print('--OUT--')
@@ -186,6 +188,20 @@ def check_discrepencies(out1, out2, threshold=1e-3):
186188
"""
187189
Compares two tensors. Raises an exception if it fails.
188190
"""
191+
if isinstance(out1, list):
192+
if len(out1) > 1:
193+
if len(out1) != len(out2):
194+
raise AssertionError(
195+
"Mismatched number of outputs, %d for ONNX, %d for TF." % (
196+
len(out1), len(out2)))
197+
for i, (a, b) in enumerate(zip(out1, out2)):
198+
try:
199+
check_discrepencies(out1[i], out2[i].numpy(), threshold=1e-3)
200+
except AssertionError as e:
201+
raise AssertionError("Discrepency with output %d." % i) from e
202+
return
203+
else:
204+
out1 = out1[0]
189205
if out1.dtype != out2.dtype:
190206
raise AssertionError("Type mismatch %r != %r." % (out1.dtype, out2.dtype))
191207
if out1.shape != out2.shape:
@@ -210,7 +226,7 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
210226
# Converts the model.
211227
if verbose:
212228
print("Convert model in %r." % dest)
213-
convert_model(tname, onnx_name, opset, tag=tag)
229+
convert_model(tname, onnx_name, opset, tag=tag, signature=signature)
214230
if verbose:
215231
print("Created %r." % onnx_name)
216232

@@ -254,9 +270,11 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
254270
index = 0
255271
if isinstance(imgs[0], dict):
256272
fct_ort = lambda img: ort.run(None, img)[index]
273+
fct_orts = lambda img: ort.run(None, img)
257274
else:
258275
input_name = ort.get_inputs()[0].name
259276
fct_ort = lambda img: ort.run(None, {input_name: img})[index]
277+
fct_orts = lambda img: ort.run(None, {input_name: img})
260278
results_ort, duration_ort = measure_time(fct_ort, imgs)
261279
if verbose:
262280
print("ORT", len(imgs), duration_ort)
@@ -294,8 +312,12 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
294312
if output_name not in res:
295313
raise AssertionError("Unable to find output %r in %r." % (output_name, list(sorted(res))))
296314
res = res[output_name]
315+
res_ort = fct_orts(imgs[0])
297316
try:
298-
check_discrepencies(fct_ort(imgs[0]), res.numpy(), threshold)
317+
if len(res_ort) > 1:
318+
check_discrepencies(res_ort, res, threshold)
319+
else:
320+
check_discrepencies(res_ort, res.numpy(), threshold)
299321
except AttributeError as e:
300322
raise AssertionError(
301323
"Unable to check discrepencies for res=%r." % res) from e
@@ -373,4 +395,4 @@ def benchmark_tflite(url, dest, onnx_name, opset, imgs, verbose=True, threshold=
373395
raise AssertionError("Unable to find output %r in %r." % (output_name, list(sorted(res))))
374396
res = res[output_name]
375397
check_discrepencies(fct_ort(imgs[0]), res.numpy(), threshold)
376-
return duration_ort, duration_tf
398+
return duration_ort, duration_tf

tests/tfhub/tfhub_enformer.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,60 @@
11
# SPDX-License-Identifier: Apache-2.0
22
import os
33
import numpy
4-
from _tools import generate_random_images, benchmark
4+
from numpy.testing import assert_almost_equal
5+
from onnxruntime import InferenceSession
6+
from _tools import generate_random_images, benchmark, measure_time
7+
from tensorflow import convert_to_tensor
8+
import tensorflow as tf
9+
import tensorflow_hub as hub
10+
import tf2onnx
511

612

713
def main(opset=13):
14+
print('[begin]')
815
url = "https://tfhub.dev/deepmind/enformer/1?tf-hub-format=compressed"
916
dest = "tf-enformer"
1017
name = "enformer"
11-
onnx_name = os.path.join(dest, "%s-%d.zip" % (name, opset))
18+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
1219

13-
imgs = generate_random_images(shape=(1, 224, 224, 3))
20+
model = None
21+
if not os.path.exists(onnx_name):
22+
if model is None:
23+
model = hub.load("https://tfhub.dev/deepmind/enformer/1").model
1424

15-
benchmark(url, dest, onnx_name, opset, imgs)
25+
tf2onnx.convert.from_function(
26+
model.predict_on_batch,
27+
[tf.TensorSpec([None, 393216, 4], tf.float32)],
28+
opset=13, output_path=onnx_name)
1629

30+
# benchmark(url, dest, onnx_name, opset, imgs)
31+
print("[generate dummy images]")
32+
imgs = generate_random_images(shape=(1, 393216, 4), scale=0.)
33+
34+
ort = InferenceSession(onnx_name)
35+
fct_ort = lambda img: ort.run(None, {'args_0': img})[0]
36+
37+
if model is None:
38+
model = hub.load("https://tfhub.dev/deepmind/enformer/1").model
39+
40+
fct_tf = lambda img: model.predict_on_batch(img)
41+
42+
print('[benchmark tf]')
43+
imgs_tf = [convert_to_tensor(img) for img in imgs]
44+
results_tf, duration_tf = measure_time(fct_tf, imgs)
45+
print("TF", len(imgs), duration_tf)
46+
47+
print('[benchmark ort]')
48+
results_ort, duration_ort = measure_time(fct_ort, imgs)
49+
print("ORT", len(imgs), duration_ort)
50+
51+
mean_ort = sum(duration_ort) / len(duration_ort)
52+
mean_tf = sum(duration_tf) / len(duration_tf)
53+
print("ratio ORT=%r / TF=%r = %r" % (mean_ort, mean_tf, mean_ort / mean_tf))
54+
55+
# discrepencies
56+
assert_almost_equal(results_tf[0]['human'], results_ort[0], decimal=4)
57+
print('[end]')
1758

1859
if __name__ == "__main__":
1960
main()

tests/tfhub/tfhub_humpback_whale.py

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
import os
33
import numpy
4+
from onnxruntime import InferenceSession
45
from _tools import generate_random_images, benchmark
56

67

@@ -10,12 +11,70 @@ def main(opset=13):
1011
name = "humpback-whale"
1112
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
1213

13-
imgs = generate_random_images(shape=(1, 1024, 1))
14-
inputs = [dict(waveform=img,
15-
context_step_samples=numpy.array(512, dtype=numpy.int64))
16-
for img in imgs]
14+
kind = "function"
15+
if kind == "function":
16+
import tensorflow as tf
17+
import tensorflow_hub as hub
18+
import tf2onnx
19+
model = hub.load('https://tfhub.dev/google/humpback_whale/1')
20+
FILENAME = 'gs://bioacoustics-www1/sounds/Cross_02_060203_071428.d20_7.wav'
21+
waveform, sample_rate = tf.audio.decode_wav(tf.io.read_file(FILENAME))
22+
waveform = tf.expand_dims(waveform, 0) # makes a batch of size 1
23+
context_step_samples = tf.cast(sample_rate, tf.int64)
24+
print(waveform.dtype, waveform.shape, sample_rate.dtype, sample_rate.shape, sample_rate)
25+
26+
spec = (tf.TensorSpec((None, ) + waveform.shape[-2:], tf.float32, name="waveform"),
27+
tf.TensorSpec((1, 1), tf.int64, name="context_step_samples"))
28+
inputs = {'waveform': waveform.numpy(),
29+
'context_step_samples': context_step_samples.numpy()}
30+
31+
tf2onnx.convert.from_function(
32+
model.signatures['score'], input_signature=spec, opset=13, output_path=onnx_name)
33+
# AttributeError: '_WrapperFunction' object has no attribute 'get_concrete_function'
1734

18-
benchmark(url, dest, onnx_name, opset, inputs, optimize=False)
35+
sess = InferenceSession(onnx_name)
36+
got = sess.run(None, inputs)
37+
print(got)
38+
39+
score_fn = model.signatures['score']
40+
scores = score_fn(waveform=waveform, context_step_samples=context_step_samples)
41+
42+
if kind == "keras":
43+
import tensorflow as tf
44+
import tensorflow_hub as hub
45+
import tf2onnx
46+
model = hub.load('https://tfhub.dev/google/humpback_whale/1').model
47+
FILENAME = 'gs://bioacoustics-www1/sounds/Cross_02_060203_071428.d20_7.wav'
48+
waveform, sample_rate = tf.audio.decode_wav(tf.io.read_file(FILENAME))
49+
waveform = tf.expand_dims(waveform, 0) # makes a batch of size 1
50+
context_step_samples = tf.cast(sample_rate, tf.int64)
51+
print(waveform.dtype, waveform.shape, sample_rate.dtype, sample_rate.shape, sample_rate)
52+
53+
spec = (tf.TensorSpec((None, ) + waveform.shape[-2:], tf.float32, name="waveform"),
54+
tf.TensorSpec((1, 1), tf.int64, name="context_step_samples"))
55+
inputs = {'waveform': waveform.numpy(),
56+
'context_step_samples': context_step_samples.numpy()}
57+
58+
tf2onnx.convert.from_keras(model, input_signature=spec, opset=13, output_path=onnx_name)
59+
# AttributeError: '_UserObject' object has no attribute 'output_names'
60+
61+
sess = InferenceSession(onnx_name)
62+
got = sess.run(None, inputs)
63+
print(got)
64+
65+
score_fn = model.signatures['score']
66+
scores = score_fn(waveform=waveform, context_step_samples=context_step_samples)
67+
68+
if kind == 'cmd':
69+
imgs = generate_random_images(shape=(1, 10000, 1), scale=1.)
70+
inputs = [dict(waveform=img,
71+
context_step_samples=numpy.array(512, dtype=numpy.int64))
72+
for img in imgs]
73+
benchmark(url, dest, onnx_name, opset, inputs, optimize=False,
74+
signature='score')
75+
# onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException:
76+
# [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'StatefulPartitionedCall/Reshape_1' Status Message: C:\xadupre\microsoft_xadupre\onnxruntime\onnxruntime\core\providers\cpu\tensor\reshape_helper.h:42 onnxruntime::ReshapeHelper::ReshapeHelper gsl::narrow_cast<int64_t>(input_shape.Size()) == size was false. The input tensor cannot be reshaped to the requested shape.
77+
# Input shape:{0,1}, requested shape:{1,1,1}
1978

2079

2180
if __name__ == "__main__":

tests/tfhub/tfhub_mobile_food_segmenter_V1.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import os
33
import numpy
44
from _tools import generate_random_images, benchmark
5+
import tf2onnx
6+
import onnxruntime as ort
57

68

79
def main(opset=13):
@@ -12,7 +14,29 @@ def main(opset=13):
1214

1315
imgs = generate_random_images(shape=(1, 513, 513, 3), scale=1.)
1416

15-
benchmark(url, dest, onnx_name, opset, imgs, tag='')
17+
if True:
18+
benchmark(url, dest, onnx_name, opset, imgs, tag='')
19+
# The conversion works but tensorflow fails with
20+
# TypeError: 'AutoTrackable' object is not callable
21+
22+
if True:
23+
import tensorflow.compat.v2 as tf
24+
import tensorflow_hub as hub
25+
26+
m = hub.KerasLayer('https://tfhub.dev/google/seefood/segmenter/mobile_food_segmenter_V1/1')
27+
inputs = {
28+
"X": tf.keras.Input(shape=[1, 513, 513, 3], dtype="float32", name="X"),
29+
}
30+
outputs = m(inputs)["default"]
31+
# TypeError: pruned(images) missing required arguments: images
32+
print(outputs)
33+
model = tf.keras.Model(inputs, outputs)
34+
35+
if not os.path.exists(dest):
36+
os.makedirs(dest)
37+
38+
# This model is a large model.
39+
tf2onnx.convert.from_keras(model, opset=13, output_path=onnx_name)
1640

1741

1842
if __name__ == "__main__":

tests/tfhub/tfhub_yamnet_coral.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import os
3+
import numpy
4+
from _tools import generate_random_images, benchmark_tflite
5+
6+
7+
def main(opset=13):
8+
url = "https://tfhub.dev/google/coral-model/yamnet/classification/coral/1?coral-format=tflite"
9+
dest = "tf-yamnet-coral"
10+
name = "yamnet"
11+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
12+
13+
imgs = generate_random_images(shape=(1, 256, 256, 3), dtype=numpy.int32)
14+
15+
benchmark_tflite(url, dest, onnx_name, opset, imgs)
16+
# WARNING - Error loading model into tflite interpreter: Encountered unresolved custom op: edgetpu-custom-op.Node number 14 (edgetpu-custom-op) failed to prepare.
17+
# WARNING - Could not parse attributes for custom op 'TFL_edgetpu-custom-op': 'utf-8' codec can't decode byte 0xc8 in position 0: invalid continuation byte
18+
# WARNING - For now, onnxruntime only support float32 type for Gemm rewriter
19+
# ERROR - Tensorflow op [tower0/network/layer32/final_output1_prequant: TFL_edgetpu-custom-op] is not supported
20+
# ERROR - Unsupported ops: Counter({'TFL_edgetpu-custom-op': 1})
21+
22+
if __name__ == "__main__":
23+
main()

tests/tfhub/tfhub_yamnet_tf.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import os
3+
import numpy
4+
from _tools import generate_random_images, benchmark
5+
6+
7+
def main(opset=13):
8+
url = "https://tfhub.dev/google/yamnet/1?tf-hub-format=compressed"
9+
dest = "tf-yamnet-tf"
10+
name = "yamnet"
11+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
12+
13+
imgs = generate_random_images(shape=(16000, ), dtype=numpy.float32, scale=0.)
14+
15+
benchmark(url, dest, onnx_name, opset, imgs)
16+
17+
18+
if __name__ == "__main__":
19+
main()

tests/tfhub/tfhub_yamnet.py renamed to tests/tfhub/tfhub_yamnet_tflite.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@
55

66

77
def main(opset=13):
8-
url = "https://tfhub.dev/google/coral-model/yamnet/classification/coral/1?coral-format=tflite"
9-
dest = "tf-yamnet"
8+
url = "https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1?lite-format=tflite"
9+
dest = "tf-yamnet-tflite"
1010
name = "yamnet"
1111
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
1212

1313
imgs = generate_random_images(shape=(1, 256, 256, 3), dtype=numpy.int32)
1414

1515
benchmark_tflite(url, dest, onnx_name, opset, imgs)
16+
# WARNING - For now, onnxruntime only support float32 type for Gemm rewriter
17+
# onnxruntime: Could not find an implementation for the node pre_tower/split_prequant:Split(13)
1618

1719

1820
if __name__ == "__main__":

0 commit comments

Comments
 (0)