File tree Expand file tree Collapse file tree 1 file changed +47
-0
lines changed Expand file tree Collapse file tree 1 file changed +47
-0
lines changed Original file line number Diff line number Diff line change
1
+ import os
2
+ import subprocess
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from tensorflow import keras
6
+ from tensorflow .keras import layers , Input
7
+
8
+
9
+ # Creates the model.
10
+ model = keras .Sequential ()
11
+ #model.add(layers.Embedding(input_dim=10, output_dim=4))
12
+ model .add (Input ((4 , 4 )))
13
+ model .add (layers .SimpleRNN (8 ))
14
+ model .add (layers .Dense (2 ))
15
+ print (model .summary ())
16
+ print (model .inputs )
17
+ print (model .outputs )
18
+
19
+ # Testing the model.
20
+ input = np .random .randn (2 , 4 , 4 ).astype (np .float32 )
21
+ expected = model .predict (input )
22
+ print (expected )
23
+
24
+ # Training
25
+ # ....
26
+
27
+ # Saves the model.
28
+ if not os .path .exists ("simple_rnn" ):
29
+ os .mkdir ("simple_rnn" )
30
+ tf .keras .models .save_model (model , "simple_rnn" )
31
+
32
+ # Run the command line.
33
+ proc = subprocess .run ('python -m tf2onnx.convert --saved-model simple_rnn '
34
+ '--output simple_rnn.onnx --opset 12' .split (),
35
+ capture_output = True )
36
+ print (proc .returncode )
37
+ print (proc .stdout .decode ('ascii' ))
38
+ print (proc .stderr .decode ('ascii' ))
39
+
40
+ # Run onnxruntime.
41
+ from onnxruntime import InferenceSession
42
+ session = InferenceSession ("simple_rnn.onnx" )
43
+ got = session .run (None , {'input_1:0' : input })
44
+ print (got [0 ])
45
+
46
+ # Differences
47
+ print (np .abs (got [0 ] - expected ).max ())
You can’t perform that action at this time.
0 commit comments