Skip to content

Commit e9569f4

Browse files
committed
expanddims_op get input rank only when necessary
1 parent 7e5dc5e commit e9569f4

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tf2onnx/tfonnx.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -881,9 +881,10 @@ def expanddims_op(ctx, node, name, args):
881881
dim_node = node.inputs[1]
882882
if dim_node.is_const():
883883
node.type = "Unsqueeze"
884-
input_rank = len(ctx.get_shape(node.input[0]))
885884
dim = dim_node.get_tensor_value()
886-
dim = dim + input_rank + 1 if dim < 0 else dim
885+
if dim < 0:
886+
input_rank = len(ctx.get_shape(node.input[0]))
887+
dim = dim + input_rank + 1
887888
node.set_attr("axes", [dim])
888889
ctx.remove_input(node, node.input[1])
889890
return

0 commit comments

Comments
 (0)