@@ -1457,13 +1457,15 @@ def fill_op(ctx, node, name, args):
1457
1457
# In onnx the value is an attribute so we need to fetch the value as const which
1458
1458
# sooner or later will be a problem for tensorflow-onnx.
1459
1459
# ConstantOfShape in onnxruntime only support int64, so insert cast op
1460
- cast_node = ctx .insert_new_node_on_input (node , "Cast" , node .input [0 ], to = onnx_pb .TensorProto .INT64 )
1460
+ input_dtype_is_int64 = utils .ONNX_TO_NUMPY_DTYPE [ctx .get_dtype (node .input [0 ])] == np .int64
1461
+ if not input_dtype_is_int64 :
1462
+ cast_node = ctx .insert_new_node_on_input (node , "Cast" , node .input [0 ], to = onnx_pb .TensorProto .INT64 )
1461
1463
dtype = ctx .get_dtype (node .output [0 ])
1462
1464
value = np .array ([node .inputs [1 ].get_tensor_value ()]).astype (utils .ONNX_TO_NUMPY_DTYPE [dtype ])
1463
1465
value_proto = numpy_helper .from_array (value )
1464
1466
node .set_attr ("value" , value_proto )
1465
1467
del node .input [1 ]
1466
- return [node , cast_node ]
1468
+ return [node ] if input_dtype_is_int64 else [ node , cast_node ]
1467
1469
1468
1470
1469
1471
def reverse_op8 (ctx , node , name , args ):
@@ -1716,9 +1718,12 @@ def logical_compare_op(ctx, node, name, args):
1716
1718
def where_op (ctx , node , name , args ):
1717
1719
# T_y output = Where(T_x condition), return indices of elements whose value are True
1718
1720
node .type = "NonZero"
1721
+ # in onnx, indices are returned in this way [[ind_a_0, ind_b_0, ...], [ind_a_1, ind_b_1,...]];
1722
+ # while in tf, the result will be [[ind_a_0, ind_a_1, ...], [ind_b_0, ind_b_1, ...], ...]
1723
+ # this is the reason a transpose node inserted here.
1719
1724
transpose_node = ctx .insert_new_node_on_output ("Transpose" , node .output [0 ], name = utils .make_name ("where_op_added" ))
1720
- ctx .set_shape ( transpose_node .output [0 ], ctx . get_shape ( node . output [0 ]) )
1721
- ctx .set_dtype ( transpose_node .output [0 ], ctx . get_dtype ( node . output [0 ]) )
1725
+ ctx .copy_shape ( node .output [0 ], transpose_node . output [0 ])
1726
+ ctx .copy_dtype ( node .output [0 ], transpose_node . output [0 ])
1722
1727
return [node , transpose_node ]
1723
1728
1724
1729
0 commit comments