Skip to content

Commit 2af4841

Browse files
authored
Merge pull request #779 from onnx/gs/unsqueeze
get_tensor_value() returns a list
2 parents 1dffcf5 + f950f51 commit 2af4841

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,9 @@ def version_7(cls, ctx, node, **kwargs):
607607
if dim_node.is_const():
608608
node.type = "Unsqueeze"
609609
dim = dim_node.get_tensor_value()
610+
# TODO: isn't this always a list ?
611+
if isinstance(dim, list):
612+
dim = dim[0]
610613
if dim < 0:
611614
input_rank = len(ctx.get_shape(node.input[0]))
612615
dim = dim + input_rank + 1

0 commit comments

Comments
 (0)