Skip to content

Commit 9fedcc2

Browse files
authored
Merge pull request #956 from onnx/gs/fix-exapanddims
dims can be a list
2 parents 91d0d78 + 60d537c commit 9fedcc2

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

tests/test_backend.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,15 +184,20 @@ def test_expand_dims_known_rank(self):
184184
def test_expand_dims_one_unknown_rank(self):
185185
x_val = make_xval([3, 4])
186186
def func(x):
187-
# FIXME: this was tf_placeholder(tf.float32, shape=[None, 4], name=_TFINPUT)
188187
op = tf.expand_dims(x, 0)
189188
return tf.identity(op, name=_TFOUTPUT)
190189
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
191190

191+
def test_expand_dims_with_list(self):
192+
x_val = make_xval([3, 4])
193+
def func(x):
194+
op = tf.expand_dims(x, [0])
195+
return tf.identity(op, name=_TFOUTPUT)
196+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
197+
192198
def _test_expand_dims_more_unknown_rank(self, idx):
193199
x_val = make_xval([3, 4])
194200
def func(x):
195-
# FIXME: this was tf_placeholder(tf.float32, shape=[None, None], name=_TFINPUT)
196201
op = tf.expand_dims(x, idx)
197202
return tf.identity(op, name=_TFOUTPUT)
198203
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

tf2onnx/onnx_opset/tensor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,6 @@ def version_7(cls, ctx, node, **kwargs):
614614
if dim_node.is_const():
615615
node.type = "Unsqueeze"
616616
dim = dim_node.get_tensor_value()
617-
# TODO: isn't this always a list ?
618617
if isinstance(dim, list):
619618
dim = dim[0]
620619
if dim < 0:
@@ -631,6 +630,9 @@ def version_11(cls, ctx, node, **kwargs):
631630
if dim_node.is_const():
632631
node.type = "Unsqueeze"
633632
dim = dim_node.get_tensor_value()
633+
if isinstance(dim, list):
634+
# tf.expanddims() wants a scalar per doc but quietly accepts a list too.
635+
dim = dim[0]
634636
node.set_attr("axes", [dim])
635637
ctx.remove_input(node, node.input[1])
636638
return

0 commit comments

Comments
 (0)