1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ """
4
+ This example shows how to convert tf functions and keras models using the Python API.
5
+ It also demonstrates converting saved_models from the command line.
6
+ """
7
+
8
+ import tensorflow as tf
9
+ import tf2onnx
10
+ import numpy as np
11
+ import onnxruntime as ort
12
+ import os
13
+
14
+ ##################### tf function #####################
15
+
16
+ @tf .function
17
+ def f (a , b ):
18
+ return a + b
19
+
20
+ input_signature = [tf .TensorSpec ([2 , 3 ], tf .float32 ), tf .TensorSpec ([2 , 3 ], tf .float32 )]
21
+ onnx_model , _ = tf2onnx .convert .from_function (f , input_signature , opset = 13 )
22
+
23
+ a_val = np .ones ([2 , 3 ], np .float32 )
24
+ b_val = np .zeros ([2 , 3 ], np .float32 )
25
+
26
+ print ("Tensorflow result" )
27
+ print (f (a_val , b_val ).numpy ())
28
+
29
+ print ("ORT result" )
30
+ sess = ort .InferenceSession (onnx_model .SerializeToString ())
31
+ res = sess .run (None , {'a' : a_val , 'b' : b_val })
32
+ print (res [0 ])
33
+
34
+
35
+ ##################### Keras Model #####################
36
+
37
+ model = tf .keras .Sequential ()
38
+ model .add (tf .keras .layers .Dense (4 , activation = "relu" ))
39
+
40
+ input_signature = [tf .TensorSpec ([3 , 3 ], tf .float32 , name = 'x' )]
41
+ onnx_model , _ = tf2onnx .convert .from_keras (model , input_signature , opset = 13 )
42
+
43
+ x_val = np .ones ((3 , 3 ), np .float32 )
44
+
45
+ print ("Keras result" )
46
+ print (model (x_val ).numpy ())
47
+
48
+ print ("ORT result" )
49
+ sess = ort .InferenceSession (onnx_model .SerializeToString ())
50
+ res = sess .run (None , {'x' : x_val })
51
+ print (res [0 ])
52
+
53
+
54
+ ##################### Saved Model #####################
55
+
56
+ model .save ("savedmodel" )
57
+ os .system ("python -m tf2onnx.convert --saved-model savedmodel --output model.onnx --opset 13" )
58
+
59
+ print ("ORT result" )
60
+ sess = ort .InferenceSession ("model.onnx" )
61
+ res = sess .run (None , {'dense_input:0' : x_val })
62
+ print (res [0 ])
63
+
64
+ print ("Conversion succeeded" )
0 commit comments