6
6
import tensorflow as tf
7
7
from test_utils import convert_keras_for_test as convert_keras
8
8
from mock_keras2onnx .proto import is_tensorflow_older_than
9
+ import tf2onnx
9
10
10
11
if (not mock_keras2onnx .proto .is_tf_keras ) or (not mock_keras2onnx .proto .tfcompat .is_tf2 ):
11
12
pytest .skip ("Tensorflow 2.0 only tests." , allow_module_level = True )
@@ -17,7 +18,7 @@ def __init__(self):
17
18
self .conv2d_1 = tf .keras .layers .Conv2D (filters = 6 ,
18
19
kernel_size = (3 , 3 ), activation = 'relu' ,
19
20
input_shape = (32 , 32 , 1 ))
20
- self .average_pool = tf .keras .layers .AveragePooling2D ()
21
+ self .average_pool = tf .keras .layers .AveragePooling2D (( 3 , 3 ) )
21
22
self .conv2d_2 = tf .keras .layers .Conv2D (filters = 16 ,
22
23
kernel_size = (3 , 3 ), activation = 'relu' )
23
24
self .flatten = tf .keras .layers .Flatten ()
@@ -91,8 +92,9 @@ def test_lenet(runner):
91
92
lenet = LeNet ()
92
93
data = np .random .rand (2 * 416 * 416 * 3 ).astype (np .float32 ).reshape (2 , 416 , 416 , 3 )
93
94
expected = lenet (data )
94
- lenet ._set_inputs (data )
95
- oxml = convert_keras (lenet )
95
+ if hasattr (lenet , "_set_inputs" ):
96
+ lenet ._set_inputs (data )
97
+ oxml = convert_keras (lenet , input_signature = [tf .TensorSpec ([None , None , None , None ], tf .float32 )])
96
98
assert runner ('lenet' , oxml , data , expected )
97
99
98
100
@@ -234,15 +236,28 @@ def call(self, inputs, **kwargs):
234
236
swm = Model ()
235
237
const_in = [tf .Variable ([2 , 4 , 6 , 8 , 10 ], dtype = tf .int32 , name = "input" )]
236
238
expected = swm (const_in )
237
- if hasattr (swm , "_set_input" ):
238
- swm ._set_inputs (const_in )
239
- else :
240
- swm .inputs_spec = const_in
241
- if hasattr (swm , "_set_output" ):
242
- swm ._set_output (expected )
243
- else :
244
- swm .outputs_spec = expected
245
- oxml = convert_keras (swm )
239
+
240
+ """
241
+ for op in concrete_func.graph.get_operations():
242
+ print("--", op.name)
243
+ print(op)
244
+
245
+ print("***", concrete_func.inputs)
246
+ print("***", concrete_func.outputs)
247
+ """
248
+ run_model = tf .function (swm )
249
+ concrete_func = run_model .get_concrete_function (tf .TensorSpec ([None ], tf .int32 ))
250
+ model_proto , external_tensor_storage = tf2onnx .convert ._convert_common (
251
+ concrete_func .graph .as_graph_def (),
252
+ input_names = [i .name for i in concrete_func .inputs ],
253
+ output_names = [i .name for i in concrete_func .outputs ],
254
+ large_model = False ,
255
+ output_path = "where_test.onnx" ,
256
+ )
257
+ assert model_proto
258
+ assert not external_tensor_storage
259
+
260
+ oxml = convert_keras (swm , input_signature = [tf .TensorSpec ([None ], tf .int32 )])
246
261
assert runner ('where_test' , oxml , const_in , expected )
247
262
248
263
0 commit comments