Skip to content

Commit 97d6f4f

Browse files
committed
code refactor
1 parent c3f04d4 commit 97d6f4f

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

tf2onnx/tfonnx.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1457,13 +1457,15 @@ def fill_op(ctx, node, name, args):
14571457
# In onnx the value is an attribute so we need to fetch the value as const which
14581458
# sooner or later will be a problem for tensorflow-onnx.
14591459
# ConstantOfShape in onnxruntime only support int64, so insert cast op
1460-
cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=onnx_pb.TensorProto.INT64)
1460+
input_dtype_is_int64 = utils.ONNX_TO_NUMPY_DTYPE[ctx.get_dtype(node.input[0])] == np.int64
1461+
if not input_dtype_is_int64:
1462+
cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=onnx_pb.TensorProto.INT64)
14611463
dtype = ctx.get_dtype(node.output[0])
14621464
value = np.array([node.inputs[1].get_tensor_value()]).astype(utils.ONNX_TO_NUMPY_DTYPE[dtype])
14631465
value_proto = numpy_helper.from_array(value)
14641466
node.set_attr("value", value_proto)
14651467
del node.input[1]
1466-
return [node, cast_node]
1468+
return [node] if input_dtype_is_int64 else [node, cast_node]
14671469

14681470

14691471
def reverse_op8(ctx, node, name, args):
@@ -1716,9 +1718,12 @@ def logical_compare_op(ctx, node, name, args):
17161718
def where_op(ctx, node, name, args):
17171719
# T_y output = Where(T_x condition), return indices of elements whose value are True
17181720
node.type = "NonZero"
1721+
# in onnx, indices are returned in this way [[ind_a_0, ind_b_0, ...], [ind_a_1, ind_b_1,...]];
1722+
# while in tf, the result will be [[ind_a_0, ind_a_1, ...], [ind_b_0, ind_b_1, ...], ...]
1723+
# this is the reason a transpose node inserted here.
17191724
transpose_node = ctx.insert_new_node_on_output("Transpose", node.output[0], name=utils.make_name("where_op_added"))
1720-
ctx.set_shape(transpose_node.output[0], ctx.get_shape(node.output[0]))
1721-
ctx.set_dtype(transpose_node.output[0], ctx.get_dtype(node.output[0]))
1725+
ctx.copy_shape(node.output[0], transpose_node.output[0])
1726+
ctx.copy_dtype(node.output[0], transpose_node.output[0])
17221727
return [node, transpose_node]
17231728

17241729

0 commit comments

Comments
 (0)