Skip to content

Commit e116e4d

Browse files
committed
Reorder some logic
1 parent 39a7280 commit e116e4d

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,16 +1064,19 @@ def version_1(cls, ctx, node, **kwargs):
10641064
axis += len(shape)
10651065
# split the tensor into n outputs
10661066
node.type = "Split"
1067-
output_shape = ctx.get_shape(node.output[0])
1068-
if output_shape:
1069-
ctx.set_shape(node.output[0], output_shape.insert(axis, 1))
1067+
10701068
# for each output we need to squeeze axis
10711069
for n in node.output:
10721070
op_name = utils.make_name(node.name)
10731071
squeeze_node = ctx.insert_new_node_on_output("Squeeze", n, name=op_name, axes=[axis])
10741072
ctx.copy_shape(n, squeeze_node.output[0])
10751073
ctx.copy_dtype(n, squeeze_node.output[0])
10761074

1075+
# split node is 1 rank higher than squeeze nodes
1076+
output_shape = ctx.get_shape(node.output[0])
1077+
if output_shape:
1078+
ctx.set_shape(node.output[0], output_shape.insert(axis, 1))
1079+
10771080

10781081
@tf_op("OneHot")
10791082
class OneHot:

0 commit comments

Comments
 (0)