Skip to content

Commit f5e4ed3

Browse files
authored
Merge pull request #974 from jignparm/jignparm/fixunpack
Unpack operator: fix incorrect shape
2 parents bcf626c + e116e4d commit f5e4ed3

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,13 +1064,19 @@ def version_1(cls, ctx, node, **kwargs):
10641064
axis += len(shape)
10651065
# split the tensor into n outputs
10661066
node.type = "Split"
1067+
10671068
# for each output we need to squeeze axis
10681069
for n in node.output:
10691070
op_name = utils.make_name(node.name)
10701071
squeeze_node = ctx.insert_new_node_on_output("Squeeze", n, name=op_name, axes=[axis])
10711072
ctx.copy_shape(n, squeeze_node.output[0])
10721073
ctx.copy_dtype(n, squeeze_node.output[0])
10731074

1075+
# split node is 1 rank higher than squeeze nodes
1076+
output_shape = ctx.get_shape(node.output[0])
1077+
if output_shape:
1078+
ctx.set_shape(node.output[0], output_shape.insert(axis, 1))
1079+
10741080

10751081
@tf_op("OneHot")
10761082
class OneHot:

0 commit comments

Comments
 (0)