Skip to content

Commit e3cbb3f

Browse files
committed
expanddims_op7 get input rank only when necessary
1 parent ad365f6 commit e3cbb3f

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tf2onnx/tfonnx.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -907,9 +907,10 @@ def expanddims_op7(ctx, node, name, args):
907907
dim_node = node.inputs[1]
908908
if dim_node.is_const():
909909
node.type = "Unsqueeze"
910-
input_rank = len(ctx.get_shape(node.input[0]))
911910
dim = dim_node.get_tensor_value()
912-
dim = dim + input_rank + 1 if dim < 0 else dim
911+
if dim < 0:
912+
input_rank = len(ctx.get_shape(node.input[0]))
913+
dim = dim + input_rank + 1
913914
node.set_attr("axes", [dim])
914915
ctx.remove_input(node, node.input[1])
915916
return

0 commit comments

Comments
 (0)