|
1 | 1 | # SPDX-License-Identifier: Apache-2.0
|
2 | 2 | import os
|
| 3 | +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" |
| 4 | +import pickle |
3 | 5 | import numpy
|
4 | 6 | from onnxruntime import InferenceSession
|
5 |
| -from _tools import generate_random_images, benchmark |
| 7 | +from _tools import generate_random_images, benchmark, measure_time |
| 8 | +import tensorflow as tf |
| 9 | +import tensorflow_hub as hub |
6 | 10 |
|
7 | 11 |
|
8 | 12 | def main(opset=13):
|
9 | 13 | url = "https://tfhub.dev/google/humpback_whale/1?tf-hub-format=compressed"
|
10 | 14 | dest = "tf-humpback-whale"
|
11 | 15 | name = "humpback-whale"
|
12 | 16 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
|
13 |
| - |
14 |
| - kind = "function" |
15 |
| - if kind == "function": |
16 |
| - import tensorflow as tf |
17 |
| - import tensorflow_hub as hub |
18 |
| - import tf2onnx |
19 |
| - model = hub.load('https://tfhub.dev/google/humpback_whale/1') |
20 |
| - FILENAME = 'gs://bioacoustics-www1/sounds/Cross_02_060203_071428.d20_7.wav' |
21 |
| - waveform, sample_rate = tf.audio.decode_wav(tf.io.read_file(FILENAME)) |
22 |
| - waveform = tf.expand_dims(waveform, 0) # makes a batch of size 1 |
23 |
| - context_step_samples = tf.cast(sample_rate, tf.int64) |
24 |
| - print(waveform.dtype, waveform.shape, sample_rate.dtype, sample_rate.shape, sample_rate) |
25 |
| - |
26 |
| - spec = (tf.TensorSpec((None, ) + waveform.shape[-2:], tf.float32, name="waveform"), |
27 |
| - tf.TensorSpec((1, 1), tf.int64, name="context_step_samples")) |
28 |
| - inputs = {'waveform': waveform.numpy(), |
29 |
| - 'context_step_samples': context_step_samples.numpy()} |
30 |
| - |
31 |
| - tf2onnx.convert.from_function( |
32 |
| - model.signatures['score'], input_signature=spec, opset=13, output_path=onnx_name) |
33 |
| - # AttributeError: '_WrapperFunction' object has no attribute 'get_concrete_function' |
34 |
| - |
35 |
| - sess = InferenceSession(onnx_name) |
36 |
| - got = sess.run(None, inputs) |
37 |
| - print(got) |
38 |
| - |
39 |
| - score_fn = model.signatures['score'] |
40 |
| - scores = score_fn(waveform=waveform, context_step_samples=context_step_samples) |
41 |
| - |
42 |
| - if kind == "keras": |
43 |
| - import tensorflow as tf |
44 |
| - import tensorflow_hub as hub |
45 |
| - import tf2onnx |
46 |
| - model = hub.load('https://tfhub.dev/google/humpback_whale/1').model |
47 |
| - FILENAME = 'gs://bioacoustics-www1/sounds/Cross_02_060203_071428.d20_7.wav' |
48 |
| - waveform, sample_rate = tf.audio.decode_wav(tf.io.read_file(FILENAME)) |
49 |
| - waveform = tf.expand_dims(waveform, 0) # makes a batch of size 1 |
50 |
| - context_step_samples = tf.cast(sample_rate, tf.int64) |
51 |
| - print(waveform.dtype, waveform.shape, sample_rate.dtype, sample_rate.shape, sample_rate) |
52 |
| - |
53 |
| - spec = (tf.TensorSpec((None, ) + waveform.shape[-2:], tf.float32, name="waveform"), |
54 |
| - tf.TensorSpec((1, 1), tf.int64, name="context_step_samples")) |
55 |
| - inputs = {'waveform': waveform.numpy(), |
56 |
| - 'context_step_samples': context_step_samples.numpy()} |
57 |
| - |
58 |
| - tf2onnx.convert.from_keras(model, input_signature=spec, opset=13, output_path=onnx_name) |
59 |
| - # AttributeError: '_UserObject' object has no attribute 'output_names' |
60 |
| - |
61 |
| - sess = InferenceSession(onnx_name) |
62 |
| - got = sess.run(None, inputs) |
63 |
| - print(got) |
64 |
| - |
| 17 | + print("[download data]") |
| 18 | + FILENAME = 'gs://bioacoustics-www1/sounds/Cross_02_060203_071428.d20_7.wav' |
| 19 | + pkl_name = os.path.join(dest, "data.pkl") |
| 20 | + if not os.path.exists(pkl_name): |
| 21 | + with open(pkl_name, "wb") as f: |
| 22 | + waveform, sample_rate = tf.audio.decode_wav(tf.io.read_file(FILENAME)) |
| 23 | + waveform = tf.expand_dims(waveform, 0) # makes a batch of size 1 |
| 24 | + context_step_samples = tf.cast(sample_rate, tf.int64) |
| 25 | + data = dict(waveform=waveform, context_step_samples=context_step_samples) |
| 26 | + pickle.dump(data, f) |
| 27 | + else: |
| 28 | + with open(pkl_name, "rb") as f: |
| 29 | + data = pickle.load(f) |
| 30 | + waveform = data["waveform"] |
| 31 | + context_step_samples = data["context_step_samples"] |
| 32 | + print("[data] done. context_step_samples=", context_step_samples.numpy()) |
| 33 | + |
| 34 | + def benchmark_custom(local_name): |
| 35 | + model = hub.load(local_name) |
65 | 36 | score_fn = model.signatures['score']
|
66 | 37 | scores = score_fn(waveform=waveform, context_step_samples=context_step_samples)
|
| 38 | + imgs_tf = [dict(waveform=waveform, context_step_samples=context_step_samples)] |
| 39 | + results_tf, duration_tf = measure_time( |
| 40 | + lambda inputs: score_fn(**inputs), imgs_tf) |
| 41 | + return scores, results_tf, duration_tf |
67 | 42 |
|
68 |
| - if kind == 'cmd': |
69 |
| - imgs = generate_random_images(shape=(1, 10000, 1), scale=1.) |
70 |
| - inputs = [dict(waveform=img, |
71 |
| - context_step_samples=numpy.array(512, dtype=numpy.int64)) |
72 |
| - for img in imgs] |
73 |
| - benchmark(url, dest, onnx_name, opset, inputs, optimize=False, |
74 |
| - signature='score') |
75 |
| - # onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: |
76 |
| - # [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'StatefulPartitionedCall/Reshape_1' Status Message: C:\xadupre\microsoft_xadupre\onnxruntime\onnxruntime\core\providers\cpu\tensor\reshape_helper.h:42 onnxruntime::ReshapeHelper::ReshapeHelper gsl::narrow_cast<int64_t>(input_shape.Size()) == size was false. The input tensor cannot be reshaped to the requested shape. |
77 |
| - # Input shape:{0,1}, requested shape:{1,1,1} |
| 43 | + imgs = generate_random_images(shape=(1, 750000, 1), scale=1., n=2) |
| 44 | + inputs = [dict(waveform=waveform.numpy(), |
| 45 | + context_step_samples=numpy.array( |
| 46 | + context_step_samples.numpy(), dtype=numpy.int64))] |
| 47 | + benchmark(url, dest, onnx_name, opset, inputs, optimize=False, |
| 48 | + signature='score', custom_tf=benchmark_custom) |
78 | 49 |
|
79 | 50 |
|
80 | 51 | if __name__ == "__main__":
|
81 |
| - main() |
| 52 | + with tf.device('/CPU:0'): |
| 53 | + main() |
0 commit comments