Skip to content

Commit 9d65c68

Browse files
committed
fix : fix axis for converting nhwc2nchw in onnx_opset
Signed-off-by: dongryeol.lee <[email protected]>
1 parent 4fed7de commit 9d65c68

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -716,9 +716,11 @@ def version_1(cls, ctx, node, **kwargs):
716716
# T output = Split(int32 split_dim, T value, @int num_split)
717717
# T outputs = Split(T input, @INT axis, @INTS split)
718718
split_dims = node.inputs[0].get_tensor_value()
719+
new_split_dims = split_dims + len(node.output_shapes[0]) if split_dims < 0 else split_dims
720+
new_split_dims = 1 if new_split_dims == 3 else new_split_dims
719721
ctx.remove_input(node, node.input[0], 0)
720722
node.set_attr("num_outputs", node.get_attr_int("num_split"))
721-
node.set_attr("axis", split_dims)
723+
node.set_attr("axis", new_split_dims)
722724

723725
@classmethod
724726
def version_2(cls, ctx, node, **kwargs):

0 commit comments

Comments
 (0)