We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4117092 commit 60d537cCopy full SHA for 60d537c
tests/test_backend.py
@@ -184,15 +184,20 @@ def test_expand_dims_known_rank(self):
184
def test_expand_dims_one_unknown_rank(self):
185
x_val = make_xval([3, 4])
186
def func(x):
187
- # FIXME: this was tf_placeholder(tf.float32, shape=[None, 4], name=_TFINPUT)
188
op = tf.expand_dims(x, 0)
189
return tf.identity(op, name=_TFOUTPUT)
190
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
191
+ def test_expand_dims_with_list(self):
192
+ x_val = make_xval([3, 4])
193
+ def func(x):
194
+ op = tf.expand_dims(x, [0])
195
+ return tf.identity(op, name=_TFOUTPUT)
196
+ self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
197
+
198
def _test_expand_dims_more_unknown_rank(self, idx):
199
200
- # FIXME: this was tf_placeholder(tf.float32, shape=[None, None], name=_TFINPUT)
201
op = tf.expand_dims(x, idx)
202
203
tf2onnx/onnx_opset/tensor.py
@@ -614,7 +614,6 @@ def version_7(cls, ctx, node, **kwargs):
614
if dim_node.is_const():
615
node.type = "Unsqueeze"
616
dim = dim_node.get_tensor_value()
617
- # TODO: isn't this always a list ?
618
if isinstance(dim, list):
619
dim = dim[0]
620
if dim < 0:
@@ -631,6 +630,9 @@ def version_11(cls, ctx, node, **kwargs):
631
630
632
633
+ if isinstance(dim, list):
634
+ # tf.expanddims() wants a scalar per doc but quietly accepts a list too.
635
+ dim = dim[0]
636
node.set_attr("axes", [dim])
637
ctx.remove_input(node, node.input[1])
638
return
0 commit comments