Skip to content

Commit cee777a

Browse files
committed
make sure shape input for reshape() is int64
1 parent a9f59dc commit cee777a

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def calc_shape(a, b):
362362
# new reshape takes new shape as input[1]
363363
op_name = utils.make_name(node.name)
364364
shape_name = utils.make_name(node.name)
365-
shape_node = ctx.make_const(shape_name, "Const", np.array(new_kernel_shape))
365+
shape_node = ctx.make_const(shape_name, "Const", np.array(new_kernel_shape, dtype=np.int64))
366366
input_name = node.input[1]
367367
reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name, name=op_name)
368368
reshape.input.append(shape_name)

0 commit comments

Comments
 (0)