
Commit 62d7d8d

Merge pull request #395 from zhijxu-MS/zhijxu/push_DT_ok
Use onehot to support some ops found in real TF models and fix a bug in LSTM
2 parents 4022a52 + 4b770c3 commit 62d7d8d

File tree: 8 files changed, +162 additions, −68 deletions

tests/run_pretrained_models.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ class Test(object):
     target = []

     def __init__(self, url, local, make_input, input_names, output_names,
-                 disabled=False, more_inputs=None, rtol=0.01, atol=0.,
+                 disabled=False, more_inputs=None, rtol=0.01, atol=1e-6,
                  check_only_shape=False, model_type="frozen", force_input_shape=False,
                  skip_tensorflow=False):
         self.url = url

tests/test_backend.py

Lines changed: 18 additions & 0 deletions
@@ -1613,6 +1613,24 @@ def test_shape_int64(self):
         kwargs = {"check_dtype": True}
         self._run_test_case([_OUTPUT], {_INPUT: x_val}, **kwargs)

+    # @check_opset_min_version(7, "broadcasting op")
+    @unittest.skip("disable it for now, since fold const has bug")
+    def test_softmax_cross_entropy_with_logits(self):
+        num_class = 5
+        data_shape = [100, num_class]
+        for np_dtype, tf_dtype in zip([np.int32, np.int64], [tf.int32, tf.int64]):
+            tf.reset_default_graph()
+            label_val = np.random.randint(0, num_class - 1, data_shape).astype(np_dtype)
+            logits_val = np.random.random(data_shape).astype(np.float32)
+
+            label = tf.placeholder(tf_dtype, shape=data_shape, name=_TFINPUT)
+            logits = tf.placeholder(tf.float32, shape=data_shape, name=_TFINPUT1)
+
+            res1 = tf.nn.softmax_cross_entropy_with_logits_v2(labels=label, logits=logits)
+            _ = tf.identity(res1, name=_TFOUTPUT)
+
+            self._run_test_case([_OUTPUT], {_INPUT: label_val, _INPUT1: logits_val}, atol=1e-5)
+
     def test_sparse_softmax_cross_entropy_with_logits(self):
         num_class = 5
         label_val = np.array([3, 2, 0, 4]).astype(np.int32)

tests/test_gru.py

Lines changed: 28 additions & 26 deletions
@@ -123,34 +123,36 @@ def test_single_dynamic_gru_seq_length_is_const(self):
                            graph_validator=lambda g: check_gru_count(g, 1))

     def test_single_dynamic_gru_seq_length_is_not_const(self):
-        units = 5
-        batch_size = 1
-        x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]], dtype=np.float32)
-        x_val = np.stack([x_val] * batch_size)
-        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
-        initializer = init_ops.constant_initializer(0.5)
-
-        y_val = np.array([5], dtype=np.int32)
-        seq_length = tf.placeholder(tf.int32, y_val.shape, name="input_2")
-
-        # no scope
-        cell = rnn.GRUCell(
-            units,
-            kernel_initializer=initializer)
-        outputs, cell_state = tf.nn.dynamic_rnn(
-            cell,
-            x,
-            dtype=tf.float32,
-            sequence_length=tf.identity(seq_length))
+        for np_dtype, tf_dtype in [[np.int32, tf.int32], [np.int64, tf.int64], [np.float32, tf.float32]]:
+            tf.reset_default_graph()
+            units = 5
+            batch_size = 1
+            x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]], dtype=np.float32)
+            x_val = np.stack([x_val] * batch_size)
+            x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+            initializer = init_ops.constant_initializer(0.5)
+
+            y_val = np.array([5], dtype=np_dtype)
+            seq_length = tf.placeholder(tf_dtype, y_val.shape, name="input_2")
+
+            # no scope
+            cell = rnn.GRUCell(
+                units,
+                kernel_initializer=initializer)
+            outputs, cell_state = tf.nn.dynamic_rnn(
+                cell,
+                x,
+                dtype=tf.float32,
+                sequence_length=tf.identity(seq_length))

-        _ = tf.identity(outputs, name="output")
-        _ = tf.identity(cell_state, name="cell_state")
+            _ = tf.identity(outputs, name="output")
+            _ = tf.identity(cell_state, name="cell_state")

-        feed_dict = {"input_1:0": x_val, "input_2:0": y_val}
-        input_names_with_port = ["input_1:0", "input_2:0"]
-        output_names_with_port = ["output:0", "cell_state:0"]
-        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06,
-                           graph_validator=lambda g: check_gru_count(g, 1))
+            feed_dict = {"input_1:0": x_val, "input_2:0": y_val}
+            input_names_with_port = ["input_1:0", "input_2:0"]
+            output_names_with_port = ["output:0", "cell_state:0"]
+            self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06,
+                               graph_validator=lambda g: check_gru_count(g, 1))

     def test_single_dynamic_gru_placeholder_input(self):
         units = 5

tests/test_lstm.py

Lines changed: 30 additions & 28 deletions
@@ -147,36 +147,38 @@ def test_single_dynamic_lstm_seq_length_is_const(self):
                            graph_validator=lambda g: check_lstm_count(g, 1))

     def test_single_dynamic_lstm_seq_length_is_not_const(self):
-        units = 5
-        batch_size = 6
-        x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]], dtype=np.float32)
-        x_val = np.stack([x_val] * batch_size)
-        state_is_tuple = True
-        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
-        initializer = init_ops.constant_initializer(0.5)
-
-        y_val = np.array([4, 3, 4, 5, 2, 1], dtype=np.int32)
-        seq_length = tf.placeholder(tf.int32, y_val.shape, name="input_2")
-
-        # no scope
-        cell = rnn.LSTMCell(
-            units,
-            initializer=initializer,
-            state_is_tuple=state_is_tuple)
-        outputs, cell_state = tf.nn.dynamic_rnn(
-            cell,
-            x,
-            dtype=tf.float32,
-            sequence_length=tf.identity(seq_length))
+        for np_dtype, tf_dtype in [[np.int32, tf.int32], [np.int64, tf.int64], [np.float32, tf.float32]]:
+            tf.reset_default_graph()
+            units = 5
+            batch_size = 6
+            x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]], dtype=np.float32)
+            x_val = np.stack([x_val] * batch_size)
+            state_is_tuple = True
+            x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+            initializer = init_ops.constant_initializer(0.5)
+
+            y_val = np.array([4, 3, 4, 5, 2, 1], dtype=np_dtype)
+            seq_length = tf.placeholder(tf_dtype, y_val.shape, name="input_2")
+
+            # no scope
+            cell = rnn.LSTMCell(
+                units,
+                initializer=initializer,
+                state_is_tuple=state_is_tuple)
+            outputs, cell_state = tf.nn.dynamic_rnn(
+                cell,
+                x,
+                dtype=tf.float32,
+                sequence_length=tf.identity(seq_length))

-        _ = tf.identity(outputs, name="output")
-        _ = tf.identity(cell_state, name="cell_state")
+            _ = tf.identity(outputs, name="output")
+            _ = tf.identity(cell_state, name="cell_state")

-        feed_dict = {"input_1:0": x_val, "input_2:0": y_val}
-        input_names_with_port = ["input_1:0", "input_2:0"]
-        output_names_with_port = ["output:0", "cell_state:0"]
-        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
-                           graph_validator=lambda g: check_lstm_count(g, 1))
+            feed_dict = {"input_1:0": x_val, "input_2:0": y_val}
+            input_names_with_port = ["input_1:0", "input_2:0"]
+            output_names_with_port = ["output:0", "cell_state:0"]
+            self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
+                               graph_validator=lambda g: check_lstm_count(g, 1))

     def test_single_dynamic_lstm_placeholder_input(self):
         units = 5

tf2onnx/function/__init__.py

Lines changed: 6 additions & 2 deletions
@@ -11,13 +11,17 @@
 from tf2onnx.function.matrixbandpart import matrixbandpart_op
 from tf2onnx.function.range import range_op7
 from tf2onnx.function.select import select_op8
-from tf2onnx.function.sparse_softmax_cross_entropy_with_logits import sparse_softmax_cross_entropy_with_logits_op
+from tf2onnx.function.softmax_cross_entropy_with_logits import softmax_cross_entropy_with_logits_op7
+from tf2onnx.function.softmax_cross_entropy_with_logits import sparse_softmax_cross_entropy_with_logits_op7
+from tf2onnx.function.softmax_cross_entropy_with_logits import sparse_softmax_cross_entropy_with_logits_op9

 __all__ = [
     "gathernd_op",
     "lstm_block_cell_op",
     "matrixbandpart_op",
     "range_op7",
     "select_op8",
-    "sparse_softmax_cross_entropy_with_logits_op"
+    "softmax_cross_entropy_with_logits_op7",
+    "sparse_softmax_cross_entropy_with_logits_op7",
+    "sparse_softmax_cross_entropy_with_logits_op9",
 ]

tf2onnx/function/sparse_softmax_cross_entropy_with_logits.py renamed to tf2onnx/function/softmax_cross_entropy_with_logits.py

Lines changed: 60 additions & 1 deletion
@@ -12,7 +12,38 @@

 # pylint: disable=unused-argument,missing-docstring

-def sparse_softmax_cross_entropy_with_logits_op(ctx, node, name, args):
+
+def _make_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node):
+    label_dtype = ctx.get_dtype(label.output[0])
+    logit_dtype = ctx.get_dtype(logit.output[0])
+    utils.make_sure(label_dtype == logit_dtype, "the following logic only works on same dtype of label and logit")
+
+    log_softmax = ctx.make_node(op_type="LogSoftmax", inputs=logit.output)
+    # implement tf.multiply(-1, tf.reduce_sum(tf.multiply(label, log_softmax), axis=1))
+    mul1 = ctx.make_node(op_type="Mul", inputs=[label.output[0], log_softmax.output[0]])
+    reduce_sum = ctx.make_node(op_type="ReduceSum", inputs=[mul1.output[0]], attr={"axes": [-1]})
+    const_negative_one = ctx.make_const(name=utils.make_name("const_negative_one"),
+                                        np_val=np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype]))
+    mul2 = ctx.make_node(op_type="Mul", inputs=[const_negative_one.output[0], reduce_sum.output[0]])
+    shapes = tf_ori_node.output_shapes
+    dtypes = tf_ori_node.output_dtypes
+    ctx.remove_node(tf_ori_node.name)
+    ctx.make_node(op_type="Squeeze", inputs=[mul2.output[0]], attr={"axes": [1]},
+                  outputs=[tf_ori_node.output[0]], shapes=[shapes[0]], dtypes=[dtypes[0]])
+
+
+def softmax_cross_entropy_with_logits_op7(ctx, node, name, args):
+    logits = node.inputs[0]
+    logit_dtype = ctx.get_dtype(logits.output[0])
+    labels = node.inputs[1]
+    label_dtype = ctx.get_dtype(labels.output[0])
+    if label_dtype != logit_dtype:
+        labels = ctx.make_node("Cast", labels.output, attr={"to": logit_dtype}, dtypes=[logit_dtype])
+
+    _make_softmax_cross_entropy_with_logits(ctx, labels, logits, node)
+
+
+def sparse_softmax_cross_entropy_with_logits_op7(ctx, node, name, args):
     # make subgraph to implement one_hot, idea comes from onehot_op
     indices_name = node.input[1]
     indices_shape = ctx.get_shape(indices_name)

@@ -92,3 +123,31 @@ def sparse_softmax_cross_entropy_with_logits_op_by_gathernd(ctx, node, name, arg
     ctx.make_node(op_type="Squeeze",
                   inputs=[mul2.output[0]], outputs=[node.output[0]],
                   attr={"axes": [1]}, shapes=[shapes[0]], dtypes=[dtypes[0]])
+
+
+def sparse_softmax_cross_entropy_with_logits_op9(ctx, node, name, args):
+    # float32/64 output = SparseSoftmaxCrossEntropyWithLogits(float32/64 features, int32/64 labels)
+    # the detail math process of this op is: a = onehot(labels), b = logsoftmax(features), reduce_sum(mul(a, b))
+    logit_node = node.inputs[0]
+    logit_shape = ctx.get_shape(node.input[0])
+    logit_dtype = ctx.get_dtype(node.input[0])
+
+    label_name = node.input[1]
+    label_dtype = ctx.get_dtype(label_name)
+
+    num_class = logit_shape[-1]
+    utils.make_sure(num_class != -1, "number of class should be known, otherwise subgraph to get the info is needed")
+    # int64 is used because of onnxruntime "onehot" only supports this dtype
+    depth_node = ctx.make_const(utils.make_name("onehot_depth"), np.array([num_class]).astype(np.int64))
+    values_node = ctx.make_const(utils.make_name("onehot_values"), np.array([0, 1]).astype(np.int64))
+    if label_dtype != TensorProto.INT64:
+        onehot_indice = ctx.make_node("Cast", [label_name], attr={"to": TensorProto.INT64}).output[0]
+    else:
+        onehot_indice = label_name
+    label_node = ctx.make_node(op_type="OneHot", inputs=[onehot_indice, depth_node.output[0], values_node.output[0]])
+    # the above logic makes output dtype of label_node now always int64
+    # make sure label has same dtype as logit
+    if logit_dtype != TensorProto.INT64:
+        label_node = ctx.make_node("Cast", label_node.output, attr={"to": logit_dtype}, dtypes=[logit_dtype])
+
+    _make_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)
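For context, both new handlers reduce to the same identity: loss = -reduce_sum(onehot(labels) * log_softmax(logits), axis=-1), which is exactly what the LogSoftmax, Mul, ReduceSum and Mul(-1) nodes above assemble. A minimal NumPy sketch of that math (illustrative only; the names below are not part of the commit):

import numpy as np


def log_softmax(x, axis=-1):
    # numerically stable log-softmax, playing the role of the ONNX LogSoftmax node
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def sparse_softmax_cross_entropy(labels, logits):
    # OneHot node (opset-9 path): turn integer labels into a one-hot matrix
    onehot = np.eye(logits.shape[-1], dtype=logits.dtype)[labels]
    # Mul + ReduceSum + Mul(-1) nodes: -sum(onehot * log_softmax(logits)) per row
    return -np.sum(onehot * log_softmax(logits), axis=-1)


logits = np.random.random((4, 5)).astype(np.float32)
labels = np.array([3, 2, 0, 4])
print(sparse_softmax_cross_entropy(labels, logits))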

tf2onnx/rewriter/unit_rewriter_base.py

Lines changed: 16 additions & 9 deletions
@@ -215,21 +215,21 @@ def find_sequence_length_node(self, rnn_scope_name):
         seq_len_node_cnt = len(seq_len_nodes)
         if seq_len_node_cnt == 0:
             return None
+
         if seq_len_node_cnt == 1:
             seq_len_node = seq_len_nodes[0]
             if seq_len_node.is_const():
                 return seq_len_node
-            # input of the "identity" node may be a "cast"
-            # if so, then we have to keep it
-            # sentence "math_ops.to_int32(sequence_length)" in tf results in the "cast" op
-            if seq_len_node.inputs[0].type == "Cast":
-                cast_node = seq_len_node.inputs[0]
-                if not cast_node.inputs[0].name.startswith(rnn_scope_name):
-                    return seq_len_node.inputs[0]
-                raise ValueError("sequence length node should be outside of rnn scope")
+
             if not seq_len_node.inputs[0].name.startswith(rnn_scope_name):
                 return seq_len_node.inputs[0]
-            raise ValueError("sequence length node should be outside of rnn scope")
+
+            # input of the "identity" node may be a "cast" op generated by "math_ops.to_int32(sequence_length)" in tf
+            # if so, then we have to find cast input as sequence node.
+            node = seq_len_node.inputs[0]
+            if node.type == "Cast" and not node.inputs[0].name.startswith(rnn_scope_name):
+                return node.inputs[0]
+
         raise ValueError("there are more sequence length nodes than expected")

     def get_rnn_input_blacklist(self, rnn_weights, rnn_props):

@@ -342,6 +342,13 @@ def process_seq_length(self, rnn_props, seq_length_node):
                                                attr={"to": onnx_pb.TensorProto.INT32})

             self.all_nodes.extend([timestep_node, tile_node, seq_length_node])
+        else:
+            # LSTM sequence_lens needs to be int32
+            ori_seq_dtype = self.g.get_dtype(seq_length_node.name)
+            if ori_seq_dtype != onnx_pb.TensorProto.INT32:
+                seq_length_node = self.g.make_node('Cast', [seq_length_node.output[0]],
+                                                   attr={"to": onnx_pb.TensorProto.INT32})
+                self.all_nodes.append(seq_length_node)

         rnn_props.onnx_input_ids["sequence_lens"] = seq_length_node.output[0]
         return seq_length_node, batchsize_node
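The else branch above is the LSTM fix mentioned in the commit title: ONNX LSTM/GRU/RNN require the sequence_lens input to be int32, so a Cast is inserted whenever the original sequence-length tensor has a different dtype (for example the int64 placeholders exercised by the new tests). A minimal standalone sketch using the onnx helper API, not the rewriter's own graph-builder calls:

from onnx import helper, TensorProto

# cast the sequence-length tensor to the int32 dtype ONNX expects for sequence_lens
cast_node = helper.make_node(
    "Cast",
    inputs=["seq_length"],
    outputs=["seq_length_int32"],
    to=TensorProto.INT32,
)

# the casted tensor is then wired into the LSTM node's fifth input (sequence_lens)
lstm_node = helper.make_node(
    "LSTM",
    inputs=["X", "W", "R", "B", "seq_length_int32"],
    outputs=["Y", "Y_h", "Y_c"],
    hidden_size=5,
)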

tf2onnx/tfonnx.py

Lines changed: 3 additions & 1 deletion
@@ -1800,7 +1800,6 @@ def where_op(ctx, node, name, args):
     "ExpandDims": (expanddims_op7, []),
     "OneHot": (onehot_op, []),
     "Reshape": (reshape_op5, []),
-    "SparseSoftmaxCrossEntropyWithLogits": (sparse_softmax_cross_entropy_with_logits_op, [])
 }

 _OPSET_6 = {

@@ -1841,6 +1840,8 @@ def where_op(ctx, node, name, args):
     "ResizeNearestNeighbor": (upsample_op7, ["Upsample", "nearest"]),
     "Sin": (direct_op, []),
     "Sub": (broadcast_op7, []),
+    "SoftmaxCrossEntropyWithLogits": (softmax_cross_entropy_with_logits_op7, []),
+    "SparseSoftmaxCrossEntropyWithLogits": (sparse_softmax_cross_entropy_with_logits_op7, []),
     "Tan": (direct_op, []),
     "Tile": (tile_op7, []),
     "TruncateDiv": (broadcast_op7, ["Div"]),

@@ -1870,6 +1871,7 @@ def where_op(ctx, node, name, args):
     "ReverseSequence": (reverse_op9, []),
     "Sign": (sign_op9, []),
     "Sinh": (direct_op, []),
+    "SparseSoftmaxCrossEntropyWithLogits": (sparse_softmax_cross_entropy_with_logits_op9, []),
     "Where": (where_op, []),
 }
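These tables register SparseSoftmaxCrossEntropyWithLogits under both opset 7 and opset 9: the opset-9 handler can rely on the ONNX OneHot op (introduced in opset 9), while the opset-7 handler builds the one-hot subgraph by hand. As a rough illustration of the versioned-table idea (hypothetical names, not tf2onnx's actual dispatch code), tables like these can be merged so the newest handler available at the target opset wins:

# hypothetical handler tables, for illustration only
_OPSET_7_TABLE = {"SparseSoftmaxCrossEntropyWithLogits": "sparse_softmax_cross_entropy_with_logits_op7"}
_OPSET_9_TABLE = {"SparseSoftmaxCrossEntropyWithLogits": "sparse_softmax_cross_entropy_with_logits_op9"}


def resolve_handlers(target_opset):
    # merge in ascending opset order so higher-opset entries override lower ones
    merged = {}
    for opset, table in [(7, _OPSET_7_TABLE), (9, _OPSET_9_TABLE)]:
        if opset <= target_opset:
            merged.update(table)
    return merged


print(resolve_handlers(8))   # falls back to the opset-7 handler
print(resolve_handlers(10))  # picks the OneHot-based opset-9 handler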
