Commit b5ee2f1

fix bug of lstm

1 parent 09c202d

3 files changed (+74, -63 lines)

tests/test_gru.py

Lines changed: 28 additions & 26 deletions
@@ -123,34 +123,36 @@ def test_single_dynamic_gru_seq_length_is_const(self):
                            graph_validator=lambda g: check_gru_count(g, 1))
 
     def test_single_dynamic_gru_seq_length_is_not_const(self):
-        units = 5
-        batch_size = 1
-        x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]], dtype=np.float32)
-        x_val = np.stack([x_val] * batch_size)
-        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
-        initializer = init_ops.constant_initializer(0.5)
-
-        y_val = np.array([5], dtype=np.int32)
-        seq_length = tf.placeholder(tf.int32, y_val.shape, name="input_2")
-
-        # no scope
-        cell = rnn.GRUCell(
-            units,
-            kernel_initializer=initializer)
-        outputs, cell_state = tf.nn.dynamic_rnn(
-            cell,
-            x,
-            dtype=tf.float32,
-            sequence_length=tf.identity(seq_length))
+        for np_dtype, tf_dtype in [[np.int32, tf.int32], [np.int64, tf.int64], [np.float32, tf.float32]]:
+            tf.reset_default_graph()
+            units = 5
+            batch_size = 1
+            x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]], dtype=np.float32)
+            x_val = np.stack([x_val] * batch_size)
+            x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+            initializer = init_ops.constant_initializer(0.5)
+
+            y_val = np.array([5], dtype=np_dtype)
+            seq_length = tf.placeholder(tf_dtype, y_val.shape, name="input_2")
+
+            # no scope
+            cell = rnn.GRUCell(
+                units,
+                kernel_initializer=initializer)
+            outputs, cell_state = tf.nn.dynamic_rnn(
+                cell,
+                x,
+                dtype=tf.float32,
+                sequence_length=tf.identity(seq_length))
 
-        _ = tf.identity(outputs, name="output")
-        _ = tf.identity(cell_state, name="cell_state")
+            _ = tf.identity(outputs, name="output")
+            _ = tf.identity(cell_state, name="cell_state")
 
-        feed_dict = {"input_1:0": x_val, "input_2:0": y_val}
-        input_names_with_port = ["input_1:0", "input_2:0"]
-        output_names_with_port = ["output:0", "cell_state:0"]
-        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06,
-                           graph_validator=lambda g: check_gru_count(g, 1))
+            feed_dict = {"input_1:0": x_val, "input_2:0": y_val}
+            input_names_with_port = ["input_1:0", "input_2:0"]
+            output_names_with_port = ["output:0", "cell_state:0"]
+            self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06,
+                               graph_validator=lambda g: check_gru_count(g, 1))
 
     def test_single_dynamic_gru_placeholder_input(self):
         units = 5
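Note on what the new dtype loop exercises: ONNX GRU/LSTM require sequence_lens to be int32, while TF 1.x dynamic_rnn applies math_ops.to_int32 to any non-int32 sequence_length, leaving a "Cast" op in the graph that the rewriter must look through. A minimal sketch of that effect (TF 1.x assumed; the GRUCell here comes from tf.nn.rnn_cell and the shapes are illustrative, not from the commit):

    import tensorflow as tf  # TF 1.x API assumed

    tf.reset_default_graph()
    x = tf.placeholder(tf.float32, [1, 5, 2], name="input_1")
    # non-int32 on purpose: dynamic_rnn casts it via math_ops.to_int32
    seq_length = tf.placeholder(tf.int64, [1], name="input_2")
    cell = tf.nn.rnn_cell.GRUCell(5)
    outputs, state = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32,
                                       sequence_length=tf.identity(seq_length))
    # the graph now contains a Cast op between the placeholder and the rnn scope
    cast_ops = [op for op in tf.get_default_graph().get_operations()
                if op.type == "Cast"]
    assert cast_ops, "dynamic_rnn should have inserted a Cast for int64 lengths"

With np.int32 no Cast is created, so the loop covers both the cast and the cast-free paths.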

tests/test_lstm.py

Lines changed: 30 additions & 28 deletions
@@ -147,36 +147,38 @@ def test_single_dynamic_lstm_seq_length_is_const(self):
                            graph_validator=lambda g: check_lstm_count(g, 1))
 
     def test_single_dynamic_lstm_seq_length_is_not_const(self):
-        units = 5
-        batch_size = 6
-        x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]], dtype=np.float32)
-        x_val = np.stack([x_val] * batch_size)
-        state_is_tuple = True
-        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
-        initializer = init_ops.constant_initializer(0.5)
-
-        y_val = np.array([4, 3, 4, 5, 2, 1], dtype=np.int32)
-        seq_length = tf.placeholder(tf.int32, y_val.shape, name="input_2")
-
-        # no scope
-        cell = rnn.LSTMCell(
-            units,
-            initializer=initializer,
-            state_is_tuple=state_is_tuple)
-        outputs, cell_state = tf.nn.dynamic_rnn(
-            cell,
-            x,
-            dtype=tf.float32,
-            sequence_length=tf.identity(seq_length))
+        for np_dtype, tf_dtype in [[np.int32, tf.int32], [np.int64, tf.int64], [np.float32, tf.float32]]:
+            tf.reset_default_graph()
+            units = 5
+            batch_size = 6
+            x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]], dtype=np.float32)
+            x_val = np.stack([x_val] * batch_size)
+            state_is_tuple = True
+            x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+            initializer = init_ops.constant_initializer(0.5)
+
+            y_val = np.array([4, 3, 4, 5, 2, 1], dtype=np_dtype)
+            seq_length = tf.placeholder(tf_dtype, y_val.shape, name="input_2")
+
+            # no scope
+            cell = rnn.LSTMCell(
+                units,
+                initializer=initializer,
+                state_is_tuple=state_is_tuple)
+            outputs, cell_state = tf.nn.dynamic_rnn(
+                cell,
+                x,
+                dtype=tf.float32,
+                sequence_length=tf.identity(seq_length))
 
-        _ = tf.identity(outputs, name="output")
-        _ = tf.identity(cell_state, name="cell_state")
+            _ = tf.identity(outputs, name="output")
+            _ = tf.identity(cell_state, name="cell_state")
 
-        feed_dict = {"input_1:0": x_val, "input_2:0": y_val}
-        input_names_with_port = ["input_1:0", "input_2:0"]
-        output_names_with_port = ["output:0", "cell_state:0"]
-        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
-                           graph_validator=lambda g: check_lstm_count(g, 1))
+            feed_dict = {"input_1:0": x_val, "input_2:0": y_val}
+            input_names_with_port = ["input_1:0", "input_2:0"]
+            output_names_with_port = ["output:0", "cell_state:0"]
+            self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
+                               graph_validator=lambda g: check_lstm_count(g, 1))
 
     def test_single_dynamic_lstm_placeholder_input(self):
         units = 5
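The LSTM test gets the same dtype parametrization. The reason every dtype must end up as int32 sits on the ONNX side: the spec declares the optional sequence_lens input of LSTM (and GRU) as tensor(int32). A sketch of the node layout the converter must emit for a non-int32 length, built directly with onnx.helper (tensor names here are illustrative only):

    from onnx import TensorProto, helper

    # sequence_lens must be int32 per the ONNX spec, so a Cast precedes the LSTM
    cast = helper.make_node("Cast", inputs=["seq_len_int64"], outputs=["seq_len_int32"],
                            to=TensorProto.INT32)
    # sequence_lens is the fifth input of the ONNX LSTM op
    lstm = helper.make_node("LSTM", inputs=["X", "W", "R", "B", "seq_len_int32"],
                            outputs=["Y", "Y_h", "Y_c"], hidden_size=5)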

tf2onnx/rewriter/unit_rewriter_base.py

Lines changed: 16 additions & 9 deletions
@@ -215,21 +215,21 @@ def find_sequence_length_node(self, rnn_scope_name):
         seq_len_node_cnt = len(seq_len_nodes)
         if seq_len_node_cnt == 0:
             return None
+
         if seq_len_node_cnt == 1:
             seq_len_node = seq_len_nodes[0]
             if seq_len_node.is_const():
                 return seq_len_node
-            # input of the "identity" node may be a "cast"
-            # if so, then we have to keep it
-            # sentence "math_ops.to_int32(sequence_length)" in tf results in the "cast" op
-            if seq_len_node.inputs[0].type == "Cast":
-                cast_node = seq_len_node.inputs[0]
-                if not cast_node.inputs[0].name.startswith(rnn_scope_name):
-                    return seq_len_node.inputs[0]
-                raise ValueError("sequence length node should be outside of rnn scope")
+
             if not seq_len_node.inputs[0].name.startswith(rnn_scope_name):
                 return seq_len_node.inputs[0]
-            raise ValueError("sequence length node should be outside of rnn scope")
+
+            # input of the "identity" node may be a "cast" op generated by "math_ops.to_int32(sequence_length)" in tf
+            # if so, then we have to find cast input as sequence node.
+            node = seq_len_node.inputs[0]
+            if node.type == "Cast" and not node.inputs[0].name.startswith(rnn_scope_name):
+                return node.inputs[0]
+
         raise ValueError("there are more sequence length nodes than expected")
 
     def get_rnn_input_blacklist(self, rnn_weights, rnn_props):

@@ -342,6 +342,13 @@ def process_seq_length(self, rnn_props, seq_length_node):
                                                attr={"to": onnx_pb.TensorProto.INT32})
 
             self.all_nodes.extend([timestep_node, tile_node, seq_length_node])
+        else:
+            # LSTM sequence_lens needs to be int32
+            ori_seq_dtype = self.g.get_dtype(seq_length_node.name)
+            if ori_seq_dtype != onnx_pb.TensorProto.INT32:
+                seq_length_node = self.g.make_node('Cast', [seq_length_node.output[0]],
+                                                   attr={"to": onnx_pb.TensorProto.INT32})
+                self.all_nodes.append(seq_length_node)
 
         rnn_props.onnx_input_ids["sequence_lens"] = seq_length_node.output[0]
         return seq_length_node, batchsize_node
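Read on its own, the new else-branch in process_seq_length amounts to the helper below. This is only a paraphrase of the diff above, assuming the tf2onnx Graph API exactly as it appears there (g.get_dtype, g.make_node, node.output); the helper name is hypothetical:

    from onnx import onnx_pb

    def ensure_int32_seq_len(g, seq_length_node, all_nodes):
        # ONNX sequence_lens must be int32; insert a Cast only when the
        # incoming dtype differs, and register the new node as the diff does
        if g.get_dtype(seq_length_node.name) != onnx_pb.TensorProto.INT32:
            seq_length_node = g.make_node("Cast", [seq_length_node.output[0]],
                                          attr={"to": onnx_pb.TensorProto.INT32})
            all_nodes.append(seq_length_node)
        return seq_length_node

Casting conditionally leaves graphs whose sequence length is already int32 untouched; the returned node is the one wired into rnn_props.onnx_input_ids["sequence_lens"].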
