Skip to content

Commit aaaea95

Browse files
authored
Fix None shape error when one input to ConcatV2 has a shape. (#2135)
* Fix None shape error when one input to ConcatV2 has a shape. I tried to convert a network where one of the inputs to a concatenation along dimension -1 had shape None. The other input did however have a shape. The conversion failed because the code only looked att the shape of the first input to determine what positive axis value to use in the concatenation. If the order of the inputs had been reversed, the conversion would have worked. I have now changed the code to look at the shapes of both input nodes. With the new code, I can convert the network. I have also verified that the resulting onnx-file works. Signed-off-by: Klas Magnusson <[email protected]> --------- Signed-off-by: Klas Magnusson <[email protected]>
1 parent 8f8d49a commit aaaea95

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

tests/test_backend.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1685,6 +1685,23 @@ def func(x1, x2, x3):
16851685
return tf.identity(x_, name=_TFOUTPUT)
16861686
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, "input3:0": x_val3})
16871687

1688+
def test_concat_negative_axis_none_shape(self):
1689+
x_val = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=np.float32).reshape((2, 3))
1690+
y_val = np.array([7.0, 8.0, 9.0, 10.0, 11.0, 12.0], dtype=np.float32).reshape((2, 3))
1691+
s1_val = np.array([1, 1], dtype=np.int32)
1692+
s2_val = np.array([1, 1], dtype=np.int32)
1693+
def func():
1694+
x = tf_placeholder(tf.float32, [2, 3], name=_TFINPUT)
1695+
y = tf_placeholder(tf.float32, [2, 3], name=_TFINPUT1)
1696+
s1 = tf_placeholder(tf.int32, [2], name="input3")
1697+
s2 = tf_placeholder(tf.int32, [2], name="input4")
1698+
s = tf.add(s1, s2)
1699+
x_with_none_shape = tf.slice(x, [0, 0], s)
1700+
t = tf.concat([x_with_none_shape, y], -1)
1701+
return tf.identity(t, name=_TFOUTPUT)
1702+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, "input3:0": s1_val, "input4:0": s2_val},
1703+
as_session=True, premade_placeholders=True)
1704+
16881705
def test_concat_const_string(self):
16891706
x_val1 = np.array([["Hello world", "abc"], ["def", "♦♥♠♣"]], dtype=str)
16901707
const_val = np.array([["Hello there", "wxyz"], ["", "π"]], dtype=str)

tf2onnx/onnx_opset/tensor.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,13 @@ def version_1(cls, ctx, node, **kwargs):
297297
ctx.remove_input(node, node.input[-1], len(node.input) - 1)
298298

299299
if axis_val < 0: # onnxruntime does not support -1 axis, but TF supports.
300-
input_shape = ctx.get_shape(node.input[0])
301-
utils.make_sure(input_shape is not None, "shape of {} is None".format(node.input[0]))
300+
input_shape = None
301+
for node_input in node.input:
302+
input_shape = ctx.get_shape(node_input)
303+
if input_shape is not None:
304+
break
305+
utils.make_sure(input_shape is not None,
306+
"the shapes of the following inputs are None: {}".format(', '.join(node.input)))
302307
axis_val = len(input_shape) + axis_val
303308
node.set_attr("axis", axis_val)
304309

0 commit comments

Comments
 (0)