@@ -64,7 +64,7 @@ def tflist_to_onnx(node_list, shape_override):
64
64
# ignore the following attributes
65
65
ignored_attr = ["unknown_rank" , "_class" , "Tshape" , "use_cudnn_on_gpu" , "Index" , "Tpaddings" ,
66
66
"TI" , "Tparams" , "Tindices" , "Tlen" , "Tdim" , "dynamic_size" , "Tmultiples" ,
67
- "output_dtype" , " Tblock_shape" , "Tcrops" , "index_type" , "Taxis" , "U" , "maxval" ,
67
+ "Tblock_shape" , "Tcrops" , "index_type" , "Taxis" , "U" , "maxval" ,
68
68
"Tout" , "Tlabels" , "Tindex" , "element_shape" ]
69
69
# some stats
70
70
op_cnt = collections .Counter ()
@@ -239,10 +239,22 @@ def arg_minmax_op(ctx, node, name, args):
239
239
dim_count = len (input_shape ) if input_shape else 0
240
240
axis = dim_count + axis
241
241
242
+ nodes = [node ]
243
+ # TF ArgMin/ArgMax may return int32 or int64
244
+ # ONNX ArgMin/ArgMax only supports int64 output, add cast if needed
245
+ if node .get_attr_int ("output_type" ) == onnx_pb .TensorProto .INT32 :
246
+ # current node will return int64 after conversion, which differs from the dtype previously obtained from TF
247
+ ctx .set_dtype (node .output [0 ], onnx_pb .TensorProto .INT64 )
248
+ op_name = utils .make_name ("Cast" )
249
+ cast_node = ctx .insert_new_node_on_output ("Cast" , node .output [0 ], name = op_name , to = onnx_pb .TensorProto .INT32 )
250
+ ctx .set_dtype (cast_node .output [0 ], onnx_pb .TensorProto .INT32 )
251
+ ctx .copy_shape (node .output [0 ], cast_node .output [0 ])
252
+ nodes .append (cast_node )
253
+
242
254
node .set_attr ("axis" , axis )
243
255
node .set_attr ("keepdims" , 0 )
244
256
ctx .remove_input (node , node .input [1 ])
245
- return node
257
+ return nodes
246
258
247
259
248
260
def reduce_op (ctx , node , name , args ):
@@ -700,7 +712,7 @@ def biasadd_op7(ctx, node, name, args):
700
712
shape0 = ctx .get_shape (node .input [0 ])
701
713
shape1 = ctx .get_shape (node .input [1 ])
702
714
if node .inputs [1 ].type == 'Const' and len (shape1 ) == 1 :
703
- new_broadcast_shape = [shape1 [0 ], ] + [1 , ] * (len (shape0 ) - 2 )
715
+ new_broadcast_shape = [shape1 [0 ]] + [1 ] * (len (shape0 ) - 2 )
704
716
shape_name = utils .make_name (node .name )
705
717
shape_const_node = ctx .make_const (shape_name , np .array (new_broadcast_shape , dtype = np .int64 ))
706
718
op_name = node .input [1 ]
@@ -1191,7 +1203,6 @@ def minmax_op(ctx, node, name, args):
1191
1203
1192
1204
1193
1205
def pack_op (ctx , node , name , args ):
1194
-
1195
1206
# hack to make up for the missing onnx pack op
1196
1207
axis = node .get_attr ("axis" ).i
1197
1208
if axis < 0 :
@@ -1333,14 +1344,18 @@ def matmul_op(ctx, node, name, args):
1333
1344
shape = ctx .get_shape (node .input [0 ])
1334
1345
if shape :
1335
1346
perm = list (range (0 , len (shape )))
1336
- tmp = perm [- 1 ]; perm [- 1 ] = perm [- 2 ]; perm [- 2 ] = tmp
1347
+ tmp = perm [- 1 ]
1348
+ perm [- 1 ] = perm [- 2 ]
1349
+ perm [- 2 ] = tmp
1337
1350
transpose = ctx .insert_new_node_on_input (node , "Transpose" , node .input [0 ], perm = perm )
1338
1351
nodes .insert (0 , transpose )
1339
1352
if transpose_b != 0 :
1340
1353
shape = ctx .get_shape (node .input [1 ])
1341
1354
if shape :
1342
1355
perm = list (range (0 , len (shape )))
1343
- tmp = perm [- 1 ]; perm [- 1 ] = perm [- 2 ]; perm [- 2 ] = tmp
1356
+ tmp = perm [- 1 ]
1357
+ perm [- 1 ] = perm [- 2 ]
1358
+ perm [- 2 ] = tmp
1344
1359
transpose = ctx .insert_new_node_on_input (node , "Transpose" , node .input [1 ], perm = perm )
1345
1360
nodes .insert (0 , transpose )
1346
1361
0 commit comments