POSSIBLE_TARGETS = [TARGET_RS4, TARGET_CAFFE2]
DEFAULT_TARGET = [TARGET_RS4, TARGET_CAFFE2]

+
def tensorflow_to_onnx(graph):
    """
    Load tensorflow graph into an onnx graph with minimal rewrites so
@@ -119,10 +120,9 @@ def _convert_shapenode_to_int64(ctx, node, input_number):
        shape_node.set_attr("value", onnx_tensor)
        return [node]
    else:
-        op_name = utils.make_name(node.name)
-        cast_op = ctx.insert_new_node_on_input(node, "Cast", name, name=op_name)
+        cast_op = ctx.insert_new_node_on_input(node, "Cast", name)
        cast_op.set_attr("to", onnx_pb.TensorProto.INT64)
-        ctx.copy_shape(name, op_name + ":0")
+        ctx.copy_shape(name, cast_op.output[0])
        return [cast_op, node]


# pylint: disable=W0613,C0111,W0612
@@ -274,8 +274,29 @@ def reshape_op(ctx, node, name, args):


def reshape_op5(ctx, node, name, args):
+    need_casting = node.dtype in [onnx_pb.TensorProto.INT32,
+                                  onnx_pb.TensorProto.INT16,
+                                  onnx_pb.TensorProto.INT64]
    # onnx wants reshape.input[1] to have the value be int64 which is not the case for tensorflow.
-    return _convert_shapenode_to_int64(ctx, node, 1)
+    nodes = _convert_shapenode_to_int64(ctx, node, 1)
+    if not need_casting:
+        # onnx reshape can handle the type - done
+        return nodes
+
+    # onnx < opset 8 does not know reshape for types other than float*, so wrap the reshape in casts
+    input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
+    input_cast.set_attr("to", onnx_pb.TensorProto.FLOAT)
+    ctx.copy_shape(name, input_cast.output[0])
+
+    # if the next node is already a cast we don't need to insert another one
+    next_nodes = ctx.find_output_consumers(node.output[0])
+    if len(next_nodes) != 1 or next_nodes[0].type != "Cast":
+        op_name = utils.make_name(node.name)
+        output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=op_name)
+        output_cast.set_attr("to", node.dtype)
+        ctx.copy_shape(name, output_cast.output[0])
+        nodes.append(output_cast)
+    return [input_cast] + nodes


NCHW_TO_NHWC = [0, 2, 3, 1]
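
Note on the reshape_op5 change above: for integer tensors the converter now emits Cast -> Reshape -> Cast, because (as the added comment says) ONNX Reshape before opset 8 is only defined for float types. A minimal sketch of the resulting subgraph built directly with onnx.helper; the tensor names here are illustrative, not what utils.make_name would generate:

    from onnx import helper, TensorProto

    # int32 input is cast to float, reshaped, then cast back to int32
    cast_in = helper.make_node("Cast", ["x_int32"], ["x_float"], to=TensorProto.FLOAT)
    reshape = helper.make_node("Reshape", ["x_float", "new_shape"], ["y_float"])
    cast_out = helper.make_node("Cast", ["y_float"], ["y_int32"], to=TensorProto.INT32)

The output cast is skipped when the node's only consumer is already a Cast, which avoids stacking two casts back to back.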
@@ -317,8 +338,7 @@ def calc_shape(a, b):
        else:
            # if input comes from an op, insert transpose op
            input_name = node.input[0]
-            op_name = utils.make_name(node.name)
-            transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name, name=op_name)
+            transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
            transpose.set_attr("perm", NHWC_TO_NCHW)
            transpose.inserted_nchw = True
            ctx.set_shape(transpose.output[0], calc_shape(ctx.get_shape(input_name), NHWC_TO_NCHW))
@@ -336,9 +356,8 @@ def calc_shape(a, b):
            parent.data_format = "NCHW"
        else:
            # kernel comes from an op, insert transpose op
-            op_name = utils.make_name(node.name)
            input_name = node.input[1]
-            transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name, name=op_name)
+            transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
            transpose.set_attr("perm", HWCN_TO_NCHW)
            transpose.inserted_nchw = True
            ctx.copy_shape(input_name, transpose.output[0])
@@ -349,18 +368,16 @@ def calc_shape(a, b):
    if new_kernel_shape:
        if ctx.opset < 5:
            # old reshape takes new shape as attribute
-            op_name = utils.make_name(node.name)
            input_name = node.input[1]
-            reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name, name=op_name)
+            reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name)
            reshape.set_attr("shape", new_kernel_shape)
            ctx.set_shape(reshape.output[0], new_kernel_shape)
        else:
            # new reshape takes new shape as input[1]
-            op_name = utils.make_name(node.name)
            shape_name = utils.make_name(node.name)
            shape_node = ctx.make_const(shape_name, "Const", np.array(new_kernel_shape, dtype=np.int64))
            input_name = node.input[1]
-            reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name, name=op_name)
+            reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name)
            reshape.input.append(shape_name)
            ctx.set_shape(reshape.output[0], new_kernel_shape)
        nodes.append(reshape)
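
The opset split above mirrors the two signatures of ONNX Reshape: before opset 5 the target shape is an attribute, from opset 5 on it is a second int64 input. A sketch of both forms for comparison; the tensor names are made up for illustration:

    from onnx import helper, TensorProto

    # opset < 5: shape is an attribute
    reshape_v1 = helper.make_node("Reshape", ["W"], ["W2"], shape=[1, 1, 3, 3])

    # opset >= 5: shape is a const int64 input
    shape_c = helper.make_node("Constant", [], ["new_shape"],
                               value=helper.make_tensor("t", TensorProto.INT64, [4], [1, 1, 3, 3]))
    reshape_v5 = helper.make_node("Reshape", ["W", "new_shape"], ["W2"])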
@@ -820,7 +837,7 @@ def minmax_op(ctx, node, name, args):
            input_node = node.inputs[i]
            dtype = ctx.dtypes[node.input[i]]
            zero_name = utils.make_name(input_node.name)
-            zero_node = ctx.make_const(zero_name, "Const", np.zeros(shapeo, dtype=utils.ONNX_TO_NUMPY_DTYPE[dtype]))
+            ctx.make_const(zero_name, "Const", np.zeros(shapeo, dtype=utils.ONNX_TO_NUMPY_DTYPE[dtype]))
            op_name = utils.make_name(input_node.name)
            output_name = op_name + ":0"
            add_node = Node(helper.make_node("Add", [input_node.output[0], zero_name],
@@ -853,6 +870,7 @@ def pack_op(ctx, node, name, args):
    ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], output_name)
    return [concat] + nodes

+
def unpack_op(ctx, node, name, args):
    # hack to make up for the missing onnx unpack op
    axis = node.get_attr("axis").i
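
For context on the "missing onnx unpack op" comment: the body of unpack_op is not shown in this diff, but the usual composition for tf.unstack (an assumption about what the omitted body emits) is a Split along the axis followed by a Squeeze of each piece:

    from onnx import helper

    # tf.unstack(x, axis=0) of a [3, H, W] tensor, as Split + Squeeze
    split = helper.make_node("Split", ["x"], ["s0", "s1", "s2"], axis=0)
    squeezes = [helper.make_node("Squeeze", ["s%d" % i], ["y%d" % i], axes=[0])
                for i in range(3)]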
@@ -870,6 +888,44 @@ def unpack_op(ctx, node, name, args):
    return nodes


+def onehot_op(ctx, node, name, args):
+    # onnx has no onehot op yet; as a workaround, gather rows from an identity (eye) matrix
+    indices_name = node.input[0]
+    indices_shape = ctx.get_shape(indices_name)
+    if len(indices_shape) != 1:
+        # TODO: this works for rank=1 but tensorflow supports more than this.
+        # Same principle should work but we need to implement our own eye.
+        raise ValueError("onehot op: only rank1 is supported")
+    axis = node.get_attr("axis")
+    # axis becomes axis for gather
+    node.set_attr("axis", 0)
+    depth = node.inputs[1].get_tensor_value()[0]
+    on = node.inputs[2].get_tensor_value()[0]
+    off = node.inputs[3].get_tensor_value()[0]
+    dtype = node.inputs[2].get_tensor_type()
+    eye = np.eye(depth, dtype=dtype)
+    if on != 0:
+        eye[eye == 1] = on
+        eye[eye == 0] = off
+    else:
+        eye[eye == 0] = off
+        eye[eye == 1] = on
+    const_name = utils.make_name(node.name)
+    ctx.make_const(const_name, "Const", eye)
+    # setup gather inputs
+    del node.input[:]
+    node.input.append(const_name)
+    node.input.append(indices_name)
+    node.type = "Gather"
+    if axis.i == 0:
+        # TODO: revisit for rank > 1
+        name = utils.make_name(node.name)
+        transpose_op = ctx.insert_new_node_on_output("Transpose", node.output[0], name)
+        ctx.copy_shape(node.output[0], transpose_op.output[0])
+        return [node, transpose_op]
+    return node
+
+
# pylint: enable=W0613,C0111,W0612

# map tensorflow ops to onnx ops. The format below is
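
The eye-plus-Gather trick in onehot_op is easiest to see in numpy. This sketch assumes on=1, off=0 and rank-1 indices, matching the constraint the handler enforces:

    import numpy as np

    indices = np.array([0, 2, 1])
    depth = 3
    eye = np.eye(depth, dtype=np.float32)  # row i is the one-hot vector for class i
    onehot = eye[indices]                  # == Gather(eye, indices, axis=0), shape [3, 3]

    # tf.one_hot with axis=0 puts the depth dimension first, hence the extra
    # Transpose that the handler appends in that case.
    assert onehot.shape == (3, 3)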
@@ -962,6 +1018,7 @@ def unpack_op(ctx, node, name, args):
_OPSET_5 = {
    "Reshape": (reshape_op5, []),
    "ExpandDims": (expanddims_op7, []),
+    "OneHot": (onehot_op, []),
}

_OPSET_6 = {
@@ -1183,7 +1240,7 @@ def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
def tf_optimize(sess, inputs, outputs, graph_def):
    """Optimize tensorflow graph for inference."""
    transforms = [
-        # "fold_constants(ignore_errors=true)",
+        "fold_constants(ignore_errors=true)",
        "fold_batch_norms",
        "fold_old_batch_norms",
    ]
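
Re-enabling fold_constants assumes these transform strings are handed to TensorFlow's Graph Transform Tool; a minimal sketch of that call under that assumption (tf 1.x API, tensorflow.tools.graph_transforms):

    from tensorflow.tools.graph_transforms import TransformGraph

    optimized_graph_def = TransformGraph(
        graph_def, inputs, outputs,
        ["fold_constants(ignore_errors=true)",
         "fold_batch_norms",
         "fold_old_batch_norms"])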