@@ -105,8 +105,25 @@ def tensorflow_to_onnx(graph):
105
105
return onnx_nodes , op_cnt , attr_cnt , output_shapes , dtypes
106
106
107
107
108
- # pylint: disable=W0613,C0111,W0612
108
def _convert_shapenode_to_int64(ctx, node, input_number):
    """Force a node's shape-like input (at index *input_number*) to be int64.

    ONNX requires shape/repeats inputs (e.g. for Reshape, Tile) to be int64,
    which tensorflow does not guarantee.

    Args:
        ctx: conversion graph context (provides initializers and node insertion).
        node: the node whose input needs the int64 shape.
        input_number: index of the shape input on *node*.

    Returns:
        A list of the nodes that now make up this op: ``[node]`` when the
        shape was a const that could be rewritten in place, or
        ``[cast_op, node]`` when a Cast node had to be inserted.
    """
    # BUG FIX: the original body hard-coded input index 1 and ignored
    # input_number; use the parameter so callers can target any input.
    shape_node = node.inputs[input_number]
    name = node.input[input_number]
    if shape_node.is_const():
        # if it is a const, change the const to be int64
        shape = shape_node.get_tensor_value()
        shape = np.array(list(shape), dtype=np.int64)
        onnx_tensor = numpy_helper.from_array(shape, name)
        ctx._initializers[name] = onnx_tensor
        shape_node.set_attr("value", onnx_tensor)
        return [node]
    else:
        # dynamic shape: insert a Cast-to-int64 in front of the input
        op_name = utils.make_name(node.name)
        cast_op = ctx.insert_new_node_on_input(node, "Cast", name, name=op_name)
        cast_op.set_attr("to", onnx_pb.TensorProto.INT64)
        ctx.copy_shape(name, op_name + ":0")
        return [cast_op, node]
109
125
126
+ # pylint: disable=W0613,C0111,W0612
110
127
111
128
def no_op (ctx , node , name , args ):
112
129
"""Skip node."""
@@ -255,23 +272,8 @@ def reshape_op(ctx, node, name, args):
255
272
256
273
257
274
def reshape_op5(ctx, node, name, args):
    """Convert tf Reshape for opset 5.

    ONNX wants reshape.input[1] to carry int64 values, which is not the
    case for tensorflow; delegate the int64 conversion to the shared helper.
    """
    return _convert_shapenode_to_int64(ctx, node, 1)
275
277
276
278
277
279
NCHW_TO_NHWC = [0 , 2 , 3 , 1 ]
@@ -690,7 +692,7 @@ def expanddims_op(ctx, node, name, args):
690
692
def expanddims_op7(ctx, node, name, args):
    """Convert tf ExpandDims (opset 7) into an ONNX Reshape.

    Materializes the already-known output shape as an int64 const and
    rewires the node to reshape to it (ONNX Reshape requires int64 shapes).
    """
    target_shape = ctx.get_shape(node.output[0])
    const_name = utils.make_name(node.name)
    ctx.make_const(const_name, "Const", np.array(target_shape, dtype=np.int64))
    node.type = "Reshape"
    node.input[1] = const_name
    return node
@@ -786,6 +788,11 @@ def topk_op(ctx, node, name, args):
786
788
return node
787
789
788
790
791
def tile_op7(ctx, node, name, args):
    """Convert tf Tile for opset 7.

    ONNX wants the repeats (shape) input to be int64; reuse the shared
    int64-conversion helper on input 1.
    """
    return _convert_shapenode_to_int64(ctx, node, 1)
794
+
795
+
789
796
# pylint: enable=W0613,C0111,W0612
790
797
791
798
# map tensorflow ops to onnx ops. The format below is
@@ -881,7 +888,7 @@ def topk_op(ctx, node, name, args):
881
888
}
882
889
883
890
_OPSET_7 = {
884
- "Tile" : (direct_op , []),
891
+ "Tile" : (tile_op7 , []),
885
892
"ResizeNearestNeighbor" : (upsample_op , []),
886
893
"BiasAdd" : (biasadd_op7 , []),
887
894
"BiasAddV1" : (biasadd_op7 , []),
0 commit comments