Skip to content

Commit da5a034

Browse files
author
wayuanho
committed
add tests and fix bugs for some cases
1 parent 20f081b commit da5a034

File tree

4 files changed

+71
-7
lines changed

4 files changed

+71
-7
lines changed

tests/test_backend.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,35 @@ def test_concat(self):
797797
_ = tf.identity(x_, name=_TFOUTPUT)
798798
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, "input3:0": x_val3})
799799

800+
def test_concat_empty_const_input(self):
801+
x_val1 = np.array([1, 2, 3], dtype=np.float32)
802+
x_val2 = np.array([], dtype=np.float32)
803+
x1 = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
804+
x2 = tf.constant(x_val2, dtype=tf.float32)
805+
x_ = tf.concat([x1, x2], 0)
806+
_ = tf.identity(x_, name=_TFOUTPUT)
807+
self._run_test_case([_OUTPUT], {_INPUT: x_val1})
808+
809+
tf.reset_default_graph()
810+
x_val1 = np.array([[1, 2, 3]], dtype=np.float32)
811+
x_val2 = np.array([[]], dtype=np.float32)
812+
x1 = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
813+
x2 = tf.constant(x_val2, dtype=tf.float32)
814+
x_ = tf.concat([x1, x2], 1)
815+
_ = tf.identity(x_, name=_TFOUTPUT)
816+
self._run_test_case([_OUTPUT], {_INPUT: x_val1})
817+
818+
tf.reset_default_graph()
819+
x_val1 = np.array([1, 2, 3], dtype=np.float32)
820+
x_val2 = np.array([], dtype=np.float32)
821+
x_val3 = np.array([13, 14, 15], dtype=np.float32)
822+
x1 = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
823+
x2 = tf.constant(x_val2, dtype=tf.float32)
824+
x3 = tf.placeholder(tf.float32, x_val3.shape, name=_TFINPUT1)
825+
x_ = tf.concat([x1, x2, x3], 0)
826+
_ = tf.identity(x_, name=_TFOUTPUT)
827+
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val3})
828+
800829
@check_opset_min_version(6, "cast")
801830
def test_concat_int64(self):
802831
x_val1 = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)

tests/test_internals.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
from collections import namedtuple
1313

1414
import graphviz as gv
15+
import numpy as np
1516
from onnx import TensorProto
16-
from onnx import helper
17+
from onnx import helper, numpy_helper
1718

1819
import tensorflow as tf
1920
from tf2onnx import utils
@@ -247,6 +248,37 @@ def test_node_attr_onnx(self):
247248
self.assertTrue("my_attr" in n1.attr)
248249
self.assertTrue("my_attr" in n1.attr_onnx)
249250

251+
def test_tensor_data(self):
252+
tensors = {
253+
"empty_tensor": np.array([], dtype=np.float32),
254+
"multi_dim_empty_tensor": np.array([[], []], dtype=np.float32),
255+
"scalar": np.array(1., dtype=np.float32),
256+
"one_item_array": np.array([1.], dtype=np.float32),
257+
"normal_array": np.array([[1., 2.], [2., 3.]], dtype=np.float32)
258+
}
259+
tf.reset_default_graph()
260+
with tf.Session() as sess:
261+
for n, data in tensors.items():
262+
tf.constant(data, dtype=tf.float32, name=n)
263+
264+
for tf_node in sess.graph.get_operations():
265+
name = tf_node.name
266+
self.assertTrue(name in tensors.keys())
267+
268+
self.assertTrue("value" in tf_node.node_def.attr)
269+
# convert to onnx tensor value
270+
tensor_value = utils.tf_to_onnx_tensor(
271+
utils.get_tf_node_attr(tf_node, "value"),
272+
name=utils.port_name(tf_node.name)
273+
)
274+
attr = helper.make_attribute("value", tensor_value)
275+
# same as node.get_tensor_value(is_list=False)
276+
actual = numpy_helper.to_array(helper.get_attribute_value(attr))
277+
278+
expected = tensors[name]
279+
280+
self.assertTrue(np.array_equal(expected, actual))
281+
250282

251283
if __name__ == '__main__':
252284
unittest_main()

tf2onnx/tfonnx.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -769,8 +769,11 @@ def concatv2_op(ctx, node, name, args):
769769
# if any input is empty, remove the input and concat the others
770770
# NOTE: workaround for https://github.com/Microsoft/onnxruntime/issues/681
771771
for i, inp in enumerate(node.inputs):
772-
if inp.is_const() and inp.get_tensor_value() == []:
772+
if inp.is_const() and inp.get_tensor_value(as_list=False).size == 0:
773773
ctx.remove_input(node, node.input[i])
774+
# all inputs are deleted
775+
if not node.input:
776+
raise RuntimeError("all inputs of {} are empty".format(name))
774777

775778
axis_node = node.inputs[-1]
776779
axis_val = axis_node.get_tensor_value()
@@ -1175,7 +1178,7 @@ def minmax_op(ctx, node, name, args):
11751178
ctx.copy_shape(inp, inp_cast.output[0])
11761179
ctx.set_dtype(inp_cast.output[0], target_dtype)
11771180
origin_dtype = ctx.get_dtype(node.output[0])
1178-
utils.make_sure(origin_dtype, "dtype of {} is None".format(node.output[0]))
1181+
utils.make_sure(origin_dtype is not None, "dtype of {} is None".format(node.output[0]))
11791182
ctx.set_dtype(node.output[0], target_dtype)
11801183
cast_name = utils.make_name(name)
11811184
cast_node = ctx.insert_new_node_on_output("Cast", node.output[0], name=cast_name, to=origin_dtype)

tf2onnx/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,18 +131,18 @@ def tf_to_onnx_tensor(tensor, name=""):
131131
tensor_content: empty
132132
tensor_shape.dim: [0]
133133
DTYPE_val: 1
134-
3. empty tensor, e.g., np.array([], dtype=DTYPE):
134+
3. empty tensor, e.g., np.array([], dtype=DTYPE) and np.array([[]], dtype=DTYPE):
135135
tensor_content: empty
136-
tensor_shape.dim: [0]
136+
tensor_shape.dim: [0] and [1, 0]
137137
DTYPE_val: empty
138138
"""
139139
new_type = TF_TO_ONNX_DTYPE[tensor.dtype]
140140
tdim = tensor.tensor_shape.dim
141141
dims = [d.size for d in tdim]
142142
is_raw, data = get_tf_tensor_data(tensor)
143143
# empty tensor
144-
if dims == [0] and not is_raw and data is None:
145-
np_data = np.array([], dtype=map_onnx_to_numpy_type(new_type))
144+
if not is_raw and data is None:
145+
np_data = np.array([], dtype=map_onnx_to_numpy_type(new_type)).reshape(dims)
146146
return numpy_helper.from_array(np_data, name=name)
147147
make_sure(data, "tensor data isn't expected to be None or empty")
148148
# scalar tensor

0 commit comments

Comments
 (0)