Skip to content

Commit 2be4cf3

Browse files
xadupre and sdpython authored
fix humpback_whale (#1671)
Signed-off-by: xavier dupré <[email protected]> Co-authored-by: xavier dupré <[email protected]>
1 parent d42dcc5 commit 2be4cf3

File tree

2 files changed

+55
-77
lines changed

2 files changed

+55
-77
lines changed

tests/tfhub/_tools.py

Lines changed: 19 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -213,7 +213,7 @@ def check_discrepencies(out1, out2, threshold=1e-3):
213213

214214
def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
215215
signature=None, tag=None, output_name=None, ort_name=None,
216-
optimize=True, convert_tflite=None):
216+
optimize=True, convert_tflite=None, custom_tf=None):
217217
"""
218218
Runs a simple benchmark.
219219
Goes through every steps (download, convert).
@@ -290,18 +290,21 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
290290
print("ORT", len(imgs), duration_ort)
291291

292292
# tensorflow
293-
import tensorflow_hub as hub
294-
from tensorflow import convert_to_tensor
295-
if isinstance(imgs[0], OrderedDict):
296-
imgs_tf = [
297-
OrderedDict((k, convert_to_tensor(v)) for k, v in img.items())
298-
for img in imgs]
293+
if custom_tf is None:
294+
import tensorflow_hub as hub
295+
from tensorflow import convert_to_tensor
296+
if isinstance(imgs[0], OrderedDict):
297+
imgs_tf = [
298+
OrderedDict((k, convert_to_tensor(v)) for k, v in img.items())
299+
for img in imgs]
300+
else:
301+
imgs_tf = [convert_to_tensor(img) for img in imgs]
302+
model = hub.load(url.split("?")[0])
303+
if signature is not None:
304+
model = model.signatures[signature]
305+
results_tf, duration_tf = measure_time(model, imgs_tf)
299306
else:
300-
imgs_tf = [convert_to_tensor(img) for img in imgs]
301-
model = hub.load(url.split("?")[0])
302-
if signature is not None:
303-
model = model.signatures[signature]
304-
results_tf, duration_tf = measure_time(model, imgs_tf)
307+
output, results_tf, duration_tf = custom_tf(tname)
305308

306309
if verbose:
307310
print("TF", len(imgs), duration_tf)
@@ -310,7 +313,10 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
310313
print("ratio ORT=%r / TF=%r = %r" % (mean_ort, mean_tf, mean_ort / mean_tf))
311314

312315
# checks discrepencies
313-
res = model(imgs_tf[0])
316+
if custom_tf is None:
317+
res = model(imgs_tf[0])
318+
else:
319+
res = output
314320
if isinstance(res, dict):
315321
if output_name is None:
316322
if len(res) != 1:

tests/tfhub/tfhub_humpback_whale.py

Lines changed: 36 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -1,81 +1,53 @@
11
# SPDX-License-Identifier: Apache-2.0
22
import os
3+
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
4+
import pickle
35
import numpy
46
from onnxruntime import InferenceSession
5-
from _tools import generate_random_images, benchmark
7+
from _tools import generate_random_images, benchmark, measure_time
8+
import tensorflow as tf
9+
import tensorflow_hub as hub
610

711

812
def main(opset=13):
913
url = "https://tfhub.dev/google/humpback_whale/1?tf-hub-format=compressed"
1014
dest = "tf-humpback-whale"
1115
name = "humpback-whale"
1216
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
13-
14-
kind = "function"
15-
if kind == "function":
16-
import tensorflow as tf
17-
import tensorflow_hub as hub
18-
import tf2onnx
19-
model = hub.load('https://tfhub.dev/google/humpback_whale/1')
20-
FILENAME = 'gs://bioacoustics-www1/sounds/Cross_02_060203_071428.d20_7.wav'
21-
waveform, sample_rate = tf.audio.decode_wav(tf.io.read_file(FILENAME))
22-
waveform = tf.expand_dims(waveform, 0) # makes a batch of size 1
23-
context_step_samples = tf.cast(sample_rate, tf.int64)
24-
print(waveform.dtype, waveform.shape, sample_rate.dtype, sample_rate.shape, sample_rate)
25-
26-
spec = (tf.TensorSpec((None, ) + waveform.shape[-2:], tf.float32, name="waveform"),
27-
tf.TensorSpec((1, 1), tf.int64, name="context_step_samples"))
28-
inputs = {'waveform': waveform.numpy(),
29-
'context_step_samples': context_step_samples.numpy()}
30-
31-
tf2onnx.convert.from_function(
32-
model.signatures['score'], input_signature=spec, opset=13, output_path=onnx_name)
33-
# AttributeError: '_WrapperFunction' object has no attribute 'get_concrete_function'
34-
35-
sess = InferenceSession(onnx_name)
36-
got = sess.run(None, inputs)
37-
print(got)
38-
39-
score_fn = model.signatures['score']
40-
scores = score_fn(waveform=waveform, context_step_samples=context_step_samples)
41-
42-
if kind == "keras":
43-
import tensorflow as tf
44-
import tensorflow_hub as hub
45-
import tf2onnx
46-
model = hub.load('https://tfhub.dev/google/humpback_whale/1').model
47-
FILENAME = 'gs://bioacoustics-www1/sounds/Cross_02_060203_071428.d20_7.wav'
48-
waveform, sample_rate = tf.audio.decode_wav(tf.io.read_file(FILENAME))
49-
waveform = tf.expand_dims(waveform, 0) # makes a batch of size 1
50-
context_step_samples = tf.cast(sample_rate, tf.int64)
51-
print(waveform.dtype, waveform.shape, sample_rate.dtype, sample_rate.shape, sample_rate)
52-
53-
spec = (tf.TensorSpec((None, ) + waveform.shape[-2:], tf.float32, name="waveform"),
54-
tf.TensorSpec((1, 1), tf.int64, name="context_step_samples"))
55-
inputs = {'waveform': waveform.numpy(),
56-
'context_step_samples': context_step_samples.numpy()}
57-
58-
tf2onnx.convert.from_keras(model, input_signature=spec, opset=13, output_path=onnx_name)
59-
# AttributeError: '_UserObject' object has no attribute 'output_names'
60-
61-
sess = InferenceSession(onnx_name)
62-
got = sess.run(None, inputs)
63-
print(got)
64-
17+
print("[download data]")
18+
FILENAME = 'gs://bioacoustics-www1/sounds/Cross_02_060203_071428.d20_7.wav'
19+
pkl_name = os.path.join(dest, "data.pkl")
20+
if not os.path.exists(pkl_name):
21+
with open(pkl_name, "wb") as f:
22+
waveform, sample_rate = tf.audio.decode_wav(tf.io.read_file(FILENAME))
23+
waveform = tf.expand_dims(waveform, 0) # makes a batch of size 1
24+
context_step_samples = tf.cast(sample_rate, tf.int64)
25+
data = dict(waveform=waveform, context_step_samples=context_step_samples)
26+
pickle.dump(data, f)
27+
else:
28+
with open(pkl_name, "rb") as f:
29+
data = pickle.load(f)
30+
waveform = data["waveform"]
31+
context_step_samples = data["context_step_samples"]
32+
print("[data] done. context_step_samples=", context_step_samples.numpy())
33+
34+
def benchmark_custom(local_name):
35+
model = hub.load(local_name)
6536
score_fn = model.signatures['score']
6637
scores = score_fn(waveform=waveform, context_step_samples=context_step_samples)
38+
imgs_tf = [dict(waveform=waveform, context_step_samples=context_step_samples)]
39+
results_tf, duration_tf = measure_time(
40+
lambda inputs: score_fn(**inputs), imgs_tf)
41+
return scores, results_tf, duration_tf
6742

68-
if kind == 'cmd':
69-
imgs = generate_random_images(shape=(1, 10000, 1), scale=1.)
70-
inputs = [dict(waveform=img,
71-
context_step_samples=numpy.array(512, dtype=numpy.int64))
72-
for img in imgs]
73-
benchmark(url, dest, onnx_name, opset, inputs, optimize=False,
74-
signature='score')
75-
# onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException:
76-
# [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'StatefulPartitionedCall/Reshape_1' Status Message: C:\xadupre\microsoft_xadupre\onnxruntime\onnxruntime\core\providers\cpu\tensor\reshape_helper.h:42 onnxruntime::ReshapeHelper::ReshapeHelper gsl::narrow_cast<int64_t>(input_shape.Size()) == size was false. The input tensor cannot be reshaped to the requested shape.
77-
# Input shape:{0,1}, requested shape:{1,1,1}
43+
imgs = generate_random_images(shape=(1, 750000, 1), scale=1., n=2)
44+
inputs = [dict(waveform=waveform.numpy(),
45+
context_step_samples=numpy.array(
46+
context_step_samples.numpy(), dtype=numpy.int64))]
47+
benchmark(url, dest, onnx_name, opset, inputs, optimize=False,
48+
signature='score', custom_tf=benchmark_custom)
7849

7950

8051
if __name__ == "__main__":
81-
main()
52+
with tf.device('/CPU:0'):
53+
main()

0 commit comments

Comments
 (0)