Skip to content

Commit f950f51

Browse files
committed
handle scalar and list
1 parent bf20ff1 commit f950f51

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,9 @@ def version_7(cls, ctx, node, **kwargs):
607607
if dim_node.is_const():
608608
node.type = "Unsqueeze"
609609
dim = dim_node.get_tensor_value()
610-
dim = dim[0]
610+
# TODO: isn't this always a list ?
611+
if isinstance(dim, list):
612+
dim = dim[0]
611613
if dim < 0:
612614
input_rank = len(ctx.get_shape(node.input[0]))
613615
dim = dim + input_rank + 1

0 commit comments

Comments
 (0)