Skip to content

Commit 62fcc8c

Browse files
committed
fix bug of bigru/bilstm
When only the "state" tensor of a bigru is consumed, the output of the bigru doesn't have a reverse op after it; the forward/backward pair can still be merged.
1 parent 003f230 commit 62fcc8c

File tree

5 files changed

+295
-5
lines changed

5 files changed

+295
-5
lines changed

tests/run_pretrained_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@
2323
import tensorflow as tf
2424
from tensorflow.core.framework import graph_pb2
2525
from tensorflow.python.framework.graph_util import convert_variables_to_constants
26+
from tensorflow.contrib.rnn import GRUBlockCell # pylint: disable=unused-import
2627
import yaml
2728
import PIL.Image
2829

2930
import tf2onnx
3031
from tf2onnx import utils
3132
from tf2onnx.graph import GraphUtil
3233
from tf2onnx.tfonnx import process_tf_graph
33-
from tensorflow.contrib.rnn import GRUBlockCell # pylint: disable=unused-import
3434

3535
# pylint: disable=broad-except,logging-not-lazy,unused-argument,unnecessary-lambda
3636

tests/test_gru.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,54 @@ def test_single_dynamic_gru_random_weights2(self):
255255
output_names_with_port = ["output:0", "cell_state:0"]
256256
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.01)
257257

258+
def test_dynamic_gru_output_consumed_only(self):
    """Unidirectional dynamic GRU where only the per-step output tensor
    (not the final state) is consumed by the graph."""
    units = 5
    batch_size = 6
    seq = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
    x_val = np.stack([seq] * batch_size)

    x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
    initializer = tf.random_uniform_initializer(-1.0, 1.0)
    cell = rnn.GRUCell(units, kernel_initializer=initializer)

    # Only "outputs" is consumed; the state is deliberately dropped.
    outputs, _ = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
    _ = tf.identity(outputs, name="output")

    self.run_test_case({"input_1:0": x_val}, ["input_1:0"], ["output:0"], 0.0001)
281+
282+
def test_dynamic_gru_state_consumed_only(self):
    """Unidirectional dynamic GRU where only the final state tensor
    (not the per-step output) is consumed by the graph."""
    units = 5
    batch_size = 6
    seq = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
    x_val = np.stack([seq] * batch_size)

    x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
    initializer = tf.random_uniform_initializer(-1.0, 1.0)
    cell = rnn.GRUCell(units, kernel_initializer=initializer)

    # Only the state is consumed; the per-step output is deliberately dropped.
    _, cell_state = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
    _ = tf.identity(cell_state, name="cell_state")

    self.run_test_case({"input_1:0": x_val}, ["input_1:0"], ["cell_state:0"], 0.0001)
305+
258306
def test_dynamic_bigru(self):
259307
units = 5
260308
batch_size = 1
@@ -320,6 +368,38 @@ def test_dynamic_bigru_output_consumed_only(self):
320368
output_names_with_port = ["output:0"]
321369
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
322370

371+
def test_dynamic_bigru_state_consumed_only(self):
    """Bidirectional dynamic GRU where only the final states (not the
    per-step outputs) are consumed by the graph.

    Exercises the case this commit fixes: with only the state consumed,
    the backward GRU's output has no Reverse op after it, but the fw/bw
    pair should still be merged into one bidirectional ONNX GRU.
    """
    units = 5
    batch_size = 1
    x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
    x_val = np.stack([x_val] * batch_size)

    x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
    initializer = init_ops.constant_initializer(0.5)

    # bigru, no scope
    cell1 = rnn.GRUCell(units, kernel_initializer=initializer)
    cell2 = rnn.GRUCell(units, kernel_initializer=initializer)
    # Per-step outputs are intentionally unused so that only the state
    # has consumers in the final graph.
    _, cell_state = tf.nn.bidirectional_dynamic_rnn(cell1, cell2, x, dtype=tf.float32)

    _ = tf.identity(cell_state, name="cell_state")

    feed_dict = {"input_1:0": x_val}
    input_names_with_port = ["input_1:0"]
    output_names_with_port = ["cell_state:0"]
    self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
402+
323403
def test_dynamic_bidirectional_but_one_gru(self):
324404
units = 5
325405
batch_size = 1
@@ -377,6 +457,33 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
377457
output_names_with_port = ["output:0"]
378458
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
379459

460+
def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
    """Bidirectional dynamic RNN sharing a single GRU cell for both
    directions, where only the final states are consumed.

    With only the state consumed, the backward GRU's output has no
    Reverse op after it; the rewriter must still handle the pair.
    """
    units = 5
    batch_size = 1
    x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
    x_val = np.stack([x_val] * batch_size)

    x = tf.placeholder(tf.float32, x_val.shape, name="input_1")

    # bigru, no scope; the same cell object serves both directions
    cell = rnn.GRUCell(units)
    # Per-step outputs are intentionally unused so that only the state
    # has consumers in the final graph.
    _, cell_state = tf.nn.bidirectional_dynamic_rnn(cell, cell, x, dtype=tf.float32)

    _ = tf.identity(cell_state, name="cell_state")

    feed_dict = {"input_1:0": x_val}
    input_names_with_port = ["input_1:0"]
    output_names_with_port = ["cell_state:0"]
    self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
486+
380487

381488
if __name__ == '__main__':
382489
Tf2OnnxBackendTestBase.trigger(GRUTests)

tests/test_grublock.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,50 @@ def test_single_dynamic_gru_random_weights2(self):
238238
output_names_with_port = ["output:0", "cell_state:0"]
239239
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.01)
240240

241+
def test_dynamic_gru_output_consumed_only(self):
    """Unidirectional dynamic GRUBlockCell where only the per-step output
    tensor (not the final state) is consumed by the graph."""
    units = 5
    batch_size = 6
    seq = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
    x_val = np.stack([seq] * batch_size)

    x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
    cell = rnn.GRUBlockCell(units)

    # Only "outputs" is consumed; the state is deliberately dropped.
    outputs, _ = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
    _ = tf.identity(outputs, name="output")

    self.run_test_case({"input_1:0": x_val}, ["input_1:0"], ["output:0"], 0.0001)
262+
263+
def test_dynamic_gru_state_consumed_only(self):
    """Unidirectional dynamic GRUBlockCell where only the final state
    tensor (not the per-step output) is consumed by the graph."""
    units = 5
    batch_size = 6
    seq = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
    x_val = np.stack([seq] * batch_size)

    x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
    cell = rnn.GRUBlockCell(units)

    # Only the state is consumed; the per-step output is deliberately dropped.
    _, cell_state = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
    _ = tf.identity(cell_state, name="cell_state")

    self.run_test_case({"input_1:0": x_val}, ["input_1:0"], ["cell_state:0"], 0.0001)
284+
241285
def test_dynamic_bigru(self):
242286
units = 5
243287
batch_size = 1
@@ -297,6 +341,35 @@ def test_dynamic_bigru_output_consumed_only(self):
297341
output_names_with_port = ["output:0"]
298342
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
299343

344+
def test_dynamic_bigru_state_consumed_only(self):
    """Bidirectional dynamic GRUBlockCell where only the final states
    (not the per-step outputs) are consumed by the graph.

    Exercises the case this commit fixes: with only the state consumed,
    the backward GRU's output has no Reverse op after it, but the fw/bw
    pair should still be merged.
    """
    units = 5
    batch_size = 1
    x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
    x_val = np.stack([x_val] * batch_size)

    x = tf.placeholder(tf.float32, x_val.shape, name="input_1")

    # bigru, no scope
    cell1 = rnn.GRUBlockCell(units)
    cell2 = rnn.GRUBlockCell(units)
    # Per-step outputs are intentionally unused so that only the state
    # has consumers in the final graph.
    _, cell_state = tf.nn.bidirectional_dynamic_rnn(cell1, cell2, x, dtype=tf.float32)

    _ = tf.identity(cell_state, name="cell_state")

    feed_dict = {"input_1:0": x_val}
    input_names_with_port = ["input_1:0"]
    output_names_with_port = ["cell_state:0"]
    self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
372+
300373
def test_dynamic_bidirectional_but_one_gru(self):
301374
units = 5
302375
batch_size = 1
@@ -352,6 +425,33 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
352425
output_names_with_port = ["output:0"]
353426
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
354427

428+
def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
    """Bidirectional dynamic RNN sharing one GRUBlockCell for both
    directions, where only the final states are consumed.

    With only the state consumed, the backward GRU's output has no
    Reverse op after it; the rewriter must still handle the pair.
    """
    units = 5
    batch_size = 1
    x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
    x_val = np.stack([x_val] * batch_size)

    x = tf.placeholder(tf.float32, x_val.shape, name="input_1")

    # bigru, no scope; the same cell object serves both directions
    cell = rnn.GRUBlockCell(units)
    # Per-step outputs are intentionally unused so that only the state
    # has consumers in the final graph.
    _, cell_state = tf.nn.bidirectional_dynamic_rnn(cell, cell, x, dtype=tf.float32)

    _ = tf.identity(cell_state, name="cell_state")

    feed_dict = {"input_1:0": x_val}
    input_names_with_port = ["input_1:0"]
    output_names_with_port = ["cell_state:0"]
    self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
454+
355455

356456
if __name__ == '__main__':
357457
Tf2OnnxBackendTestBase.trigger(GRUBlockTests)

tests/test_lstm.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
2121

22+
2223
class LSTMTests(Tf2OnnxBackendTestBase):
2324
def test_test_single_dynamic_lstm_state_is_tuple(self):
2425
self.internal_test_single_dynamic_lstm(True)
@@ -334,6 +335,52 @@ def test_dynamic_basiclstm(self):
334335
output_names_with_port = ["output:0", "cell_state:0"]
335336
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.0001)
336337

338+
def test_dynamic_lstm_output_consumed_only(self):
    """Unidirectional dynamic LSTM where only the per-step output tensor
    (not the final state) is consumed by the graph."""
    units = 5
    batch_size = 6
    seq = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
    x_val = np.stack([seq] * batch_size)

    x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
    cell = rnn.LSTMCell(units, state_is_tuple=True)

    # Only "outputs" is consumed; the state is deliberately dropped.
    outputs, _ = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
    _ = tf.identity(outputs, name="output")

    self.run_test_case({"input_1:0": x_val}, ["input_1:0"], ["output:0"], 0.0001)
360+
361+
def test_dynamic_lstm_state_consumed_only(self):
    """Unidirectional dynamic LSTM where only the final state tensors
    (not the per-step outputs) are consumed by the graph."""
    units = 5
    batch_size = 6
    seq = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
    x_val = np.stack([seq] * batch_size)

    x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
    cell = rnn.LSTMCell(units, state_is_tuple=True)

    # Only the state is consumed; the per-step output is deliberately dropped.
    _, cell_state = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
    _ = tf.identity(cell_state, name="cell_state")

    self.run_test_case({"input_1:0": x_val}, ["input_1:0"], ["cell_state:0"], 0.0001)
383+
337384
def test_dynamic_bilstm_state_is_tuple(self):
338385
self.internal_test_dynamic_bilstm_with_parameters(True)
339386

@@ -409,6 +456,40 @@ def test_dynamic_bilstm_output_consumed_only(self, state_is_tuple=True):
409456
output_names_with_port = ["output:0"]
410457
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
411458

459+
def test_dynamic_bilstm_state_consumed_only(self, state_is_tuple=True):
    """Bidirectional dynamic LSTM where only the final states (not the
    per-step outputs) are consumed by the graph.

    Exercises the case this commit fixes: with only the state consumed,
    the backward LSTM's output has no Reverse op after it, but the fw/bw
    pair should still be merged.
    """
    units = 5
    batch_size = 6
    x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
    x_val = np.stack([x_val] * batch_size)

    x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
    initializer = init_ops.constant_initializer(0.5)

    # bilstm, no scope
    cell1 = rnn.LSTMCell(
        units,
        initializer=initializer,
        state_is_tuple=state_is_tuple)  # state_is_tuple will impact Pack node (for cell_state)'s usage pattern
    cell2 = rnn.LSTMCell(
        units,
        initializer=initializer,
        state_is_tuple=state_is_tuple)
    # Per-step outputs are intentionally unused so that only the state
    # has consumers in the final graph.
    _, cell_state = tf.nn.bidirectional_dynamic_rnn(cell1, cell2, x, dtype=tf.float32)

    _ = tf.identity(cell_state, name="cell_state")

    feed_dict = {"input_1:0": x_val}
    input_names_with_port = ["input_1:0"]
    output_names_with_port = ["cell_state:0"]
    self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
492+
412493

413494
if __name__ == '__main__':
414495
Tf2OnnxBackendTestBase.trigger(LSTMTests)

tf2onnx/rewriter/bigru_rewriter.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,12 @@ def rewrite_bidirectional_grus(g, ops):
153153
is_backward_gru = True
154154

155155
if is_backward_gru:
156-
# make sure reverse gru output will be reversed back
157-
if get_reverse_nodes_after_y_output(g, n):
158-
log.debug("find bw gru %s", input_id)
159-
bw_gru[input_id] = [input_id, n]
156+
# if output 0 is consumed, and there is no reverse after the gru output.
157+
# it's not reversed gru
158+
if g.find_output_consumers(n.output[0]) and not get_reverse_nodes_after_y_output(g, n):
159+
continue
160+
log.debug("find bw gru %s", input_id)
161+
bw_gru[input_id] = [input_id, n]
160162
else:
161163
log.debug("find fw gru %s", input_id)
162164
fw_gru[input_id] = [input_id, n]

0 commit comments

Comments
 (0)