Skip to content

Commit 39a7280

Browse files
committed
Unpack operator: fix incorrect shape
1 parent 3bf9eae commit 39a7280

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,6 +1064,9 @@ def version_1(cls, ctx, node, **kwargs):
10641064
axis += len(shape)
10651065
# split the tensor into n outputs
10661066
node.type = "Split"
1067+
output_shape = ctx.get_shape(node.output[0])
1068+
if output_shape:
1069+
ctx.set_shape(node.output[0], output_shape.insert(axis, 1))
10671070
# for each output we need to squeeze axis
10681071
for n in node.output:
10691072
op_name = utils.make_name(node.name)

0 commit comments

Comments
 (0)