Skip to content

Commit d5d990e

Browse files
Merge pull request #1119 from onnx/tom/AddCastAfterSize
Add cast after Size if expected type is not int64
2 parents c9b4219 + 359e69c commit d5d990e

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,16 @@ def _wrap_concat_with_cast(ctx, node):
6161
class Size:
6262
@classmethod
6363
def version_1(cls, ctx, node, **kwargs):
64-
ctx.set_dtype(node.output[0], onnx_pb.TensorProto.INT64)
64+
output_name = node.output[0]
65+
dtype = ctx.get_dtype(output_name)
66+
# TF size can output int32 or int64 but onnx only does int 64
67+
if dtype != onnx_pb.TensorProto.INT64:
68+
ctx.set_dtype(output_name, onnx_pb.TensorProto.INT64)
69+
output_cast = ctx.insert_new_node_on_output("Cast", output_name, name=node.child_name(),
70+
to=dtype)
71+
ctx.set_dtype(output_cast.output[0], dtype)
72+
ctx.copy_shape(output_name, output_cast.output[0])
73+
6574

6675

6776
@tf_op("Flatten")

0 commit comments

Comments
 (0)