Skip to content

Commit 0ea9a15

Browse files
committed
add an end2end example
Signed-off-by: xavier dupré <[email protected]>
1 parent 07c51be commit 0ea9a15

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

examples/end2end_tfkeras.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import os
2+
import subprocess
3+
import numpy as np
4+
import tensorflow as tf
5+
from tensorflow import keras
6+
from tensorflow.keras import layers, Input
7+
8+
9+
# Creates the model.
10+
model = keras.Sequential()
11+
#model.add(layers.Embedding(input_dim=10, output_dim=4))
12+
model.add(Input((4, 4)))
13+
model.add(layers.SimpleRNN(8))
14+
model.add(layers.Dense(2))
15+
print(model.summary())
16+
print(model.inputs)
17+
print(model.outputs)
18+
19+
# Testing the model.
20+
input = np.random.randn(2, 4, 4).astype(np.float32)
21+
expected = model.predict(input)
22+
print(expected)
23+
24+
# Training
25+
# ....
26+
27+
# Saves the model.
28+
if not os.path.exists("simple_rnn"):
29+
os.mkdir("simple_rnn")
30+
tf.keras.models.save_model(model, "simple_rnn")
31+
32+
# Run the command line.
33+
proc = subprocess.run('python -m tf2onnx.convert --saved-model simple_rnn '
34+
'--output simple_rnn.onnx --opset 12'.split(),
35+
capture_output=True)
36+
print(proc.returncode)
37+
print(proc.stdout.decode('ascii'))
38+
print(proc.stderr.decode('ascii'))
39+
40+
# Run onnxruntime.
41+
from onnxruntime import InferenceSession
42+
session = InferenceSession("simple_rnn.onnx")
43+
got = session.run(None, {'input_1:0': input})
44+
print(got[0])
45+
46+
# Differences
47+
print(np.abs(got[0] - expected).max())

0 commit comments

Comments
 (0)