Skip to content

Commit c6e55d6

Browse files
authored
Merge pull request #398 from lucienwang1009/pr_for_new_model
fix some bugs
2 parents 4a20e9e + da5a034 commit c6e55d6

File tree

6 files changed

+145
-21
lines changed

6 files changed

+145
-21
lines changed

tests/run_pretrained_models.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import time
1616
import traceback
1717
import zipfile
18+
import logging
1819

1920
import PIL.Image
2021
import numpy as np
@@ -34,6 +35,9 @@
3435

3536
# pylint: disable=broad-except,logging-not-lazy,unused-argument,unnecessary-lambda
3637

38+
logging.basicConfig(level=logging.INFO)
39+
log = logging.getLogger("tf2onnx")
40+
3741
TEMP_DIR = os.path.join(utils.get_temp_directory(), "run_pretrained")
3842
PERFITER = 1000
3943

@@ -246,9 +250,10 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
246250
for k in inputs.keys(): # pylint: disable=consider-iterating-dictionary
247251
t = sess.graph.get_tensor_by_name(k)
248252
dtype = tf.as_dtype(t.dtype).name
249-
if type != "float32":
250-
v = inputs[k]
251-
inputs[k] = v.astype(dtype)
253+
v = inputs[k]
254+
if dtype != v.dtype:
255+
log.warning("input dtype doesn't match tensorflow's")
256+
inputs[k] = np.array(v, dtype=dtype)
252257
if self.force_input_shape:
253258
for k, v in inputs.items():
254259
shape_override[k] = list(v.shape)

tests/test_backend.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,15 @@ def test_min(self):
596596
_ = tf.identity(mi, name=_TFOUTPUT)
597597
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
598598

599+
tf.reset_default_graph()
600+
x_val1 = np.array([4.0, 16.0, 4.0, 1.6], dtype=np.int32).reshape((2, 2))
601+
x_val2 = np.array([4.0, 4.0, 4.0, 4.0], dtype=np.int32).reshape((2, 2))
602+
x1 = tf.placeholder(tf.int32, x_val1.shape, name=_TFINPUT)
603+
x2 = tf.placeholder(tf.int32, x_val2.shape, name=_TFINPUT1)
604+
mi = tf.minimum(x1, x2)
605+
_ = tf.identity(mi, name=_TFOUTPUT)
606+
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
607+
599608
@skip_caffe2_backend("issue with broadcasting scalar")
600609
@check_onnxruntime_incompatibility("Sub")
601610
def test_min_broadcast(self):
@@ -788,6 +797,35 @@ def test_concat(self):
788797
_ = tf.identity(x_, name=_TFOUTPUT)
789798
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, "input3:0": x_val3})
790799

800+
def test_concat_empty_const_input(self):
801+
x_val1 = np.array([1, 2, 3], dtype=np.float32)
802+
x_val2 = np.array([], dtype=np.float32)
803+
x1 = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
804+
x2 = tf.constant(x_val2, dtype=tf.float32)
805+
x_ = tf.concat([x1, x2], 0)
806+
_ = tf.identity(x_, name=_TFOUTPUT)
807+
self._run_test_case([_OUTPUT], {_INPUT: x_val1})
808+
809+
tf.reset_default_graph()
810+
x_val1 = np.array([[1, 2, 3]], dtype=np.float32)
811+
x_val2 = np.array([[]], dtype=np.float32)
812+
x1 = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
813+
x2 = tf.constant(x_val2, dtype=tf.float32)
814+
x_ = tf.concat([x1, x2], 1)
815+
_ = tf.identity(x_, name=_TFOUTPUT)
816+
self._run_test_case([_OUTPUT], {_INPUT: x_val1})
817+
818+
tf.reset_default_graph()
819+
x_val1 = np.array([1, 2, 3], dtype=np.float32)
820+
x_val2 = np.array([], dtype=np.float32)
821+
x_val3 = np.array([13, 14, 15], dtype=np.float32)
822+
x1 = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
823+
x2 = tf.constant(x_val2, dtype=tf.float32)
824+
x3 = tf.placeholder(tf.float32, x_val3.shape, name=_TFINPUT1)
825+
x_ = tf.concat([x1, x2, x3], 0)
826+
_ = tf.identity(x_, name=_TFOUTPUT)
827+
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val3})
828+
791829
@check_opset_min_version(6, "cast")
792830
def test_concat_int64(self):
793831
x_val1 = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)

tests/test_internals.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
from collections import namedtuple
1313

1414
import graphviz as gv
15+
import numpy as np
1516
from onnx import TensorProto
16-
from onnx import helper
17+
from onnx import helper, numpy_helper
1718

1819
import tensorflow as tf
1920
from tf2onnx import utils
@@ -247,6 +248,37 @@ def test_node_attr_onnx(self):
247248
self.assertTrue("my_attr" in n1.attr)
248249
self.assertTrue("my_attr" in n1.attr_onnx)
249250

251+
def test_tensor_data(self):
252+
tensors = {
253+
"empty_tensor": np.array([], dtype=np.float32),
254+
"multi_dim_empty_tensor": np.array([[], []], dtype=np.float32),
255+
"scalar": np.array(1., dtype=np.float32),
256+
"one_item_array": np.array([1.], dtype=np.float32),
257+
"normal_array": np.array([[1., 2.], [2., 3.]], dtype=np.float32)
258+
}
259+
tf.reset_default_graph()
260+
with tf.Session() as sess:
261+
for n, data in tensors.items():
262+
tf.constant(data, dtype=tf.float32, name=n)
263+
264+
for tf_node in sess.graph.get_operations():
265+
name = tf_node.name
266+
self.assertTrue(name in tensors.keys())
267+
268+
self.assertTrue("value" in tf_node.node_def.attr)
269+
# convert to onnx tensor value
270+
tensor_value = utils.tf_to_onnx_tensor(
271+
utils.get_tf_node_attr(tf_node, "value"),
272+
name=utils.port_name(tf_node.name)
273+
)
274+
attr = helper.make_attribute("value", tensor_value)
275+
# same as node.get_tensor_value(is_list=False)
276+
actual = numpy_helper.to_array(helper.get_attribute_value(attr))
277+
278+
expected = tensors[name]
279+
280+
self.assertTrue(np.array_equal(expected, actual))
281+
250282

251283
if __name__ == '__main__':
252284
unittest_main()

tf2onnx/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -908,7 +908,7 @@ def insert_new_node_on_input(self, node, op_type, input_name, name=None, domain=
908908
break
909909
return new_node
910910

911-
def insert_new_node_on_output(self, op_type, output_name, name=None, domain=None, **kwargs):
911+
def insert_new_node_on_output(self, op_type, output_name, name, domain=None, **kwargs):
912912
"""Create and insert a new node into the graph.
913913
Args:
914914
op_type: type for new operation

tf2onnx/tfonnx.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,15 @@ def concat_op(ctx, node, name, args):
767767
def concatv2_op(ctx, node, name, args):
768768
# T output = ConcatV2(T values, Tidx axis, @int N, @type Tidx)
769769
# T concat_result = Concat(T inputs, @INT axis)
770+
# if any input is empty, remove the input and concat the others
771+
# NOTE: workaround for https://github.com/Microsoft/onnxruntime/issues/681
772+
for i, inp in enumerate(node.inputs):
773+
if inp.is_const() and inp.get_tensor_value(as_list=False).size == 0:
774+
ctx.remove_input(node, node.input[i])
775+
# all inputs are deleted
776+
if not node.input:
777+
raise RuntimeError("all inputs of {} are empty".format(name))
778+
770779
axis_node = node.inputs[-1]
771780
axis_val = axis_node.get_tensor_value()
772781
ctx.remove_input(node, node.input[-1])
@@ -1155,6 +1164,28 @@ def minmax_op(ctx, node, name, args):
11551164
# handle this by doing something like:
11561165
# y = min(x1, add(x2, sub(x1, x1))), where x1, x2 are the inputs and x2 is a scalar
11571166
# this will create a tensor of zeros of the shape of x1, adds x2 to it (which broadcasts) and use that for min.
1167+
# support more dtype
1168+
supported_dtypes = [
1169+
onnx_pb.TensorProto.FLOAT,
1170+
onnx_pb.TensorProto.FLOAT16,
1171+
onnx_pb.TensorProto.DOUBLE
1172+
]
1173+
target_dtype = onnx_pb.TensorProto.FLOAT
1174+
for inp in node.input:
1175+
dtype = ctx.get_dtype(inp)
1176+
utils.make_sure(dtype, "dtype of {} is None".format(inp))
1177+
if dtype not in supported_dtypes:
1178+
inp_cast = ctx.insert_new_node_on_input(node, "Cast", inp, to=target_dtype)
1179+
ctx.copy_shape(inp, inp_cast.output[0])
1180+
ctx.set_dtype(inp_cast.output[0], target_dtype)
1181+
origin_dtype = ctx.get_dtype(node.output[0])
1182+
utils.make_sure(origin_dtype is not None, "dtype of {} is None".format(node.output[0]))
1183+
ctx.set_dtype(node.output[0], target_dtype)
1184+
cast_name = utils.make_name(name)
1185+
cast_node = ctx.insert_new_node_on_output("Cast", node.output[0], name=cast_name, to=origin_dtype)
1186+
to_replace = [n for n in ctx.get_nodes() if n != cast_node]
1187+
ctx.replace_all_inputs(to_replace, node.output[0], cast_node.output[0])
1188+
11581189
shapeo = ctx.get_shape(node.output[0])
11591190
needs_broadcast_op = []
11601191
has_correct_shape = []
@@ -1677,6 +1708,7 @@ def where_op(ctx, node, name, args):
16771708
"BiasAdd": (biasadd_op, []),
16781709
"BiasAddV1": (biasadd_op, []),
16791710
"Cast": (cast_op, []),
1711+
"CheckNumerics": (identity_op, ["Identity"]),
16801712
"Concat": (concat_op, ["Concat"]),
16811713
"ConcatV2": (concatv2_op, ["Concat"]),
16821714
"Const": (direct_op, []),

tf2onnx/utils.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -120,21 +120,39 @@ def split_nodename_and_shape(name):
120120

121121

122122
def tf_to_onnx_tensor(tensor, name=""):
123-
"""Convert tensorflow tensor to onnx tensor."""
123+
"""
124+
Convert tensorflow tensor to onnx tensor.
125+
Here deal with three types of tensor:
126+
1. normal tensor, e.g., np.array([1,2,3], dtype=DTYPE):
127+
tensor_content: raw data of [1,2,3]
128+
tensor_shape.dim: [3]
129+
DTYPE_val: empty
130+
2. scalar tensor, e.g., np.array(1, dtype=DTYPE):
131+
tensor_content: empty
132+
tensor_shape.dim: [0]
133+
DTYPE_val: 1
134+
3. empty tensor, e.g., np.array([], dtype=DTYPE) and np.array([[]], dtype=DTYPE):
135+
tensor_content: empty
136+
tensor_shape.dim: [0] and [1, 0]
137+
DTYPE_val: empty
138+
"""
124139
new_type = TF_TO_ONNX_DTYPE[tensor.dtype]
125140
tdim = tensor.tensor_shape.dim
126141
dims = [d.size for d in tdim]
127-
# FIXME: something is fishy here
128-
if dims == [0]:
129-
dims = [1]
130142
is_raw, data = get_tf_tensor_data(tensor)
143+
# empty tensor
144+
if not is_raw and data is None:
145+
np_data = np.array([], dtype=map_onnx_to_numpy_type(new_type)).reshape(dims)
146+
return numpy_helper.from_array(np_data, name=name)
147+
make_sure(data, "tensor data isn't expected to be None or empty")
148+
# scalar tensor
149+
if dims == [0] and not is_raw and len(data) == 1:
150+
return helper.make_tensor(name, new_type, [], data, False)
131151
if not is_raw and len(data) == 1 and np.prod(dims) > 1:
132152
batch_data = np.zeros(dims, dtype=map_onnx_to_numpy_type(new_type))
133153
batch_data.fill(data[0])
134-
onnx_tensor = numpy_helper.from_array(batch_data, name=name)
135-
else:
136-
onnx_tensor = helper.make_tensor(name, new_type, dims, data, is_raw)
137-
return onnx_tensor
154+
return numpy_helper.from_array(batch_data, name=name)
155+
return helper.make_tensor(name, new_type, dims, data, is_raw)
138156

139157

140158
def get_tf_tensor_data(tensor):
@@ -154,16 +172,15 @@ def get_tf_tensor_data(tensor):
154172
data = tensor.int64_val
155173
elif tensor.bool_val:
156174
data = tensor.bool_val
157-
elif tensor.dtype == tf.int32:
158-
data = [0]
159-
elif tensor.dtype == tf.int64:
160-
data = [0]
161-
elif tensor.dtype == tf.float32:
162-
data = [0.]
163-
elif tensor.dtype == tf.float16:
164-
data = [0]
165175
elif tensor.string_val:
166176
data = tensor.string_val
177+
elif tensor.dtype in [
178+
tf.int32,
179+
tf.int64,
180+
tf.float32,
181+
tf.float16
182+
]:
183+
data = None
167184
else:
168185
raise ValueError('tensor data not supported')
169186
return [is_raw, data]

0 commit comments

Comments
 (0)