
Commit 03cf6cd

support multiple bi-RNNs that share the same input
1 parent 4437438 commit 03cf6cd

File tree

8 files changed, +407 -224 lines
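All of the new test cases below exercise the same graph shape: two bidirectional RNN layers, with different hidden sizes, consuming one shared placeholder. A minimal sketch of that pattern (TF 1.x API; the names and shapes here are chosen for illustration, not taken from the diff):

    import tensorflow as tf  # TF 1.x API

    x = tf.placeholder(tf.float32, [1, 3, 2], name="input_1")

    # two independent bidirectional GRU layers reading the same tensor
    out1, state1 = tf.nn.bidirectional_dynamic_rnn(
        tf.nn.rnn_cell.GRUCell(5), tf.nn.rnn_cell.GRUCell(5),
        x, dtype=tf.float32, scope="bigru_1")
    out2, state2 = tf.nn.bidirectional_dynamic_rnn(
        tf.nn.rnn_cell.GRUCell(10), tf.nn.rnn_cell.GRUCell(10),
        x, dtype=tf.float32, scope="bigru_2")

Before this change the bidirectional rewriters paired forward and backward RNNs by input tensor alone (see the deleted dictionary-based matching in bigru_rewriter.py below), so only one bi-RNN per input could be recognized; the rewriter changes in this commit lift that restriction.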

tests/test_gru.py

Lines changed: 50 additions & 0 deletions
@@ -482,6 +482,56 @@ def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
                            graph_validator=lambda g: check_gru_count(g, 1))
 
+    def test_dynamic_multi_bigru_with_same_input(self):
+        units = 5
+        batch_size = 1
+        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
+        x_val = np.stack([x_val] * batch_size)
+
+        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+        initializer = init_ops.constant_initializer(0.5)
+
+        # first bigru
+        cell1 = rnn.GRUCell(
+            units,
+            kernel_initializer=initializer)
+        cell2 = rnn.GRUCell(
+            units,
+            kernel_initializer=initializer)
+        outputs_1, cell_state_1 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32,
+            scope="bigru_1"
+        )
+
+        units = 10
+        cell1 = rnn.GRUCell(
+            units,
+            kernel_initializer=initializer)
+        cell2 = rnn.GRUCell(
+            units,
+            kernel_initializer=initializer)
+        outputs_2, cell_state_2 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32,
+            scope="bigru_2"
+        )
+
+        _ = tf.identity(outputs_1, name="output_1")
+        _ = tf.identity(cell_state_1, name="cell_state_1")
+        _ = tf.identity(outputs_2, name="output_2")
+        _ = tf.identity(cell_state_2, name="cell_state_2")
+
+        feed_dict = {"input_1:0": x_val}
+        input_names_with_port = ["input_1:0"]
+        output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
+                           graph_validator=lambda g: check_gru_count(g, 2))
+
 
 if __name__ == '__main__':
     unittest_main()
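To run just this new case in isolation (assuming pytest is available; these unittest-style tests are also collectible by pytest):

    python -m pytest tests/test_gru.py -k test_dynamic_multi_bigru_with_same_input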

tests/test_grublock.py

Lines changed: 41 additions & 0 deletions
@@ -447,6 +447,47 @@ def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
                            graph_validator=lambda g: check_gru_count(g, 1))
 
+    def test_dynamic_multi_bigru_with_same_input(self):
+        units = 5
+        batch_size = 1
+        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
+        x_val = np.stack([x_val] * batch_size)
+
+        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+
+        # first bigru
+        cell1 = rnn.GRUBlockCell(units)
+        cell2 = rnn.GRUBlockCell(units)
+        outputs_1, cell_state_1 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32,
+            scope="bigru_1"
+        )
+
+        units = 10
+        cell1 = rnn.GRUBlockCell(units)
+        cell2 = rnn.GRUBlockCell(units)
+        outputs_2, cell_state_2 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32,
+            scope="bigru_2"
+        )
+
+        _ = tf.identity(outputs_1, name="output_1")
+        _ = tf.identity(cell_state_1, name="cell_state_1")
+        _ = tf.identity(outputs_2, name="output_2")
+        _ = tf.identity(cell_state_2, name="cell_state_2")
+
+        feed_dict = {"input_1:0": x_val}
+        input_names_with_port = ["input_1:0"]
+        output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
+                           graph_validator=lambda g: check_gru_count(g, 2))
+
 
 if __name__ == '__main__':
     unittest_main()

tests/test_lstm.py

Lines changed: 63 additions & 0 deletions
@@ -554,6 +554,69 @@ def test_dynamic_bilstm_state_consumed_only(self, state_is_tuple=True):
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
                            graph_validator=lambda g: check_lstm_count(g, 1))
 
+    def test_dynamic_multi_bilstm_with_same_input_state_is_tuple(self):
+        self.internal_test_dynamic_multi_bilstm_with_same_input(True)
+
+    def test_dynamic_multi_bilstm_with_same_input_state_is_list(self):
+        self.internal_test_dynamic_multi_bilstm_with_same_input(False)
+
+    def internal_test_dynamic_multi_bilstm_with_same_input(self, state_is_tuple):
+        units = 5
+        batch_size = 1
+        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
+        x_val = np.stack([x_val] * batch_size)
+
+        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+        initializer = init_ops.constant_initializer(0.5)
+
+        cell1 = rnn.LSTMCell(
+            units,
+            initializer=initializer,
+            state_is_tuple=state_is_tuple
+        )
+        cell2 = rnn.LSTMCell(
+            units,
+            initializer=initializer,
+            state_is_tuple=state_is_tuple
+        )
+        outputs_1, cell_state_1 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32,
+            scope="bilstm_1"
+        )
+
+        units = 10
+        cell1 = rnn.LSTMCell(
+            units,
+            initializer=initializer,
+            state_is_tuple=state_is_tuple
+        )
+        cell2 = rnn.LSTMCell(
+            units,
+            initializer=initializer,
+            state_is_tuple=state_is_tuple
+        )
+        outputs_2, cell_state_2 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32,
+            scope="bilstm_2"
+        )
+
+        _ = tf.identity(outputs_1, name="output_1")
+        _ = tf.identity(cell_state_1, name="cell_state_1")
+        _ = tf.identity(outputs_2, name="output_2")
+        _ = tf.identity(cell_state_2, name="cell_state_2")
+
+        feed_dict = {"input_1:0": x_val}
+        input_names_with_port = ["input_1:0"]
+        output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
+                           graph_validator=lambda g: check_lstm_count(g, 2))
+
 
 if __name__ == '__main__':
     unittest_main()

tests/test_lstmblock.py

Lines changed: 40 additions & 0 deletions
@@ -488,5 +488,45 @@ def test_multi_rnn_lstm(self):
         output_names_with_port = ["output:0", "cell_state:0"]
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-03, atol=1e-06)
 
+    def test_dynamic_multi_bilstm_with_same_input(self):
+        units = 5
+        batch_size = 1
+        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
+        x_val = np.stack([x_val] * batch_size)
+
+        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+
+        cell1 = rnn.LSTMBlockCell(units)
+        cell2 = rnn.LSTMBlockCell(units)
+        outputs_1, cell_state_1 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32,
+            scope="bilstm_1"
+        )
+
+        units = 10
+        cell1 = rnn.LSTMBlockCell(units)
+        cell2 = rnn.LSTMBlockCell(units)
+        outputs_2, cell_state_2 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32,
+            scope="bilstm_2"
+        )
+
+        _ = tf.identity(outputs_1, name="output_1")
+        _ = tf.identity(cell_state_1, name="cell_state_1")
+        _ = tf.identity(outputs_2, name="output_2")
+        _ = tf.identity(cell_state_2, name="cell_state_2")
+
+        feed_dict = {"input_1:0": x_val}
+        input_names_with_port = ["input_1:0"]
+        output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
+                           graph_validator=lambda g: check_lstm_count(g, 2))
+
 if __name__ == '__main__':
     unittest_main()

tf2onnx/graph.py

Lines changed: 9 additions & 0 deletions
@@ -1036,6 +1036,15 @@ def find_output_consumers(self, output_name):
             nodes.extend(g.find_output_consumers(output_name))
         return nodes
 
+    def find_common_consumers(self, *outputs):
+        if not outputs:
+            return False
+
+        common_consumer = set.intersection(
+            *[set(self.find_output_consumers(out)) for out in outputs]
+        )
+        return list(common_consumer)
+
     @staticmethod
     def replace_all_inputs(ops, old_input, new_input):
         """Replace all inputs pointing to old_input with new_input."""

tf2onnx/rewriter/bigru_rewriter.py

Lines changed: 8 additions & 69 deletions
@@ -13,24 +13,18 @@
 import logging
 import numpy as np
 from tf2onnx import utils
-from tf2onnx.utils import is_tf_reverse_op
-from tf2onnx.rewriter.bilstm_rewriter import slice_bilstm_for_original_lstm_consumers,\
-    get_reverse_nodes_after_y_output, get_np_val_for_const, _process_single_init_node
-
+from tf2onnx.rewriter.rnn_utils import find_bidirectional_rnns, ONNX_RNN_TYPE, get_np_val_for_const, \
+    process_single_init_node, slice_birnn_for_original_rnn_consumers
 
 
 logger = logging.getLogger(__name__)
 
 # pylint: disable=invalid-name,unused-argument,missing-docstring
 
 def process_bigru(g, bi_grus):
-    for fw, bw in bi_grus:
-        input_id = fw[0]
+    for gru_fw, gru_bw in bi_grus:
         logger.debug("=========================")
-        logger.debug("start handling potential bidirectional gru %s", input_id)
-
-        gru_fw = fw[1]
-        gru_bw = bw[1]
+        logger.debug("start handling potential bidirectional gru: %s, %s", gru_fw.name, gru_bw.name)
 
         w_fw = get_np_val_for_const(g, gru_fw, 1)
         w_bw = get_np_val_for_const(g, gru_bw, 1)
@@ -102,80 +96,25 @@ def process_bigru(g, bi_grus):
 
         to_remove = [gru_fw.name, gru_fw.input[1], gru_fw.input[2], gru_fw.input[3],
                      gru_bw.name, gru_bw.input[1], gru_bw.input[2], gru_bw.input[3]]
-        slice_bilstm_for_original_lstm_consumers(
+        slice_birnn_for_original_rnn_consumers(
             g, gru_fw, gru_bw, bi_gru_node, 0, all_nodes, to_remove)
-        slice_bilstm_for_original_lstm_consumers(
+        slice_birnn_for_original_rnn_consumers(
            g, gru_fw, gru_bw, bi_gru_node, 1, all_nodes, to_remove)
 
-        gru_bw_old_x = gru_bw.input[0]
-
         for n in to_remove:
             g.remove_node(n)
 
-        old_x_consumers = g.find_output_consumers(gru_bw_old_x)
-        # the transpose/reverse here must be followed by GRU if it is still useful.
-        # this is guaranteed by dynamic_rnn logic.
-        old_x_has_gru_as_consumer = [
-            n for n in old_x_consumers if n.type == "GRU"]
-        if not old_x_has_gru_as_consumer:
-            logger.debug("plan to remove useless reverse op in bw")
-            reverse_node = g.get_node_by_output(gru_bw_old_x)
-
-            if reverse_node.type == "Transpose":
-                reverse_node = reverse_node.inputs[0]
-
-            g.replace_all_inputs(
-                g.get_nodes(), reverse_node.output[0], reverse_node.input[0])
-            g.remove_node(reverse_node.name)
-        else:
-            raise ValueError(
-                "Reverse is still used by GRU as input, cannot remove")
-
     return g.get_nodes()
 
 
 def process_init_nodes(g, gru_fw, gru_bw, to_append):
-    initializer_node = _process_single_init_node(
+    initializer_node = process_single_init_node(
         g, gru_fw.input[5], gru_bw.input[5], to_append)
 
     return initializer_node
 
 
 def rewrite_bidirectional_grus(g, ops):
-    """
-    return: list of tuple, format of tuple is
-    ((fw input_id, fw onnx gru node), (bw input_id, bw onnx gru node)), and fw input_id equals to bw input_id
-    """
-    fw_gru = {}
-    bw_gru = {}
-    for n in g.get_nodes():
-        if n.type != "GRU":
-            continue
-        input_id = n.input[0]
-        temp = n.inputs[0]
-        is_backward_gru = False
-        if temp.type == "Transpose":
-            input_id = temp.input[0]
-            temp = temp.inputs[0]
-
-        if is_tf_reverse_op(temp):
-            input_id = temp.input[0]
-            is_backward_gru = True
-
-        if is_backward_gru:
-            # if output 0 is consumed, and there is no reverse after the gru output.
-            # it's not reversed gru
-            if g.find_output_consumers(n.output[0]) and not get_reverse_nodes_after_y_output(g, n):
-                continue
-            logger.debug("find bw gru %s", input_id)
-            bw_gru[input_id] = [input_id, n]
-        else:
-            logger.debug("find fw gru %s", input_id)
-            fw_gru[input_id] = [input_id, n]
-
-    # when fw_gru has same input as bw_gru, then it may be a bi gru
-    bigru_input = list(set(fw_gru.keys()).intersection(bw_gru.keys()))
-    bi_grus = [(fw_gru[input_id], bw_gru[input_id])
-               for input_id in bigru_input]
+    bi_grus = find_bidirectional_rnns(g, ops, ONNX_RNN_TYPE.GRU)
 
     return process_bigru(g, bi_grus)
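With several bi-RNNs now allowed on one input, pairing forward and backward RNNs by input tensor alone (the deleted dictionary logic above) becomes ambiguous, so the detection moves into the shared find_bidirectional_rnns helper in rnn_utils. One plausible ingredient of the disambiguation, and presumably why Graph.find_common_consumers was added in this commit, is requiring that a candidate pair's outputs meet at a common downstream node. A simplified, hypothetical sketch of that pairing idea (not the actual rnn_utils code):

    def pair_bidirectional_rnns(g, fw_rnns, bw_rnns):
        # fw_rnns / bw_rnns: candidate forward / backward RNN nodes
        # that all read the same input tensor
        pairs = []
        for fw in fw_rnns:
            for bw in bw_rnns:
                # outputs of a real bi-RNN pair are joined downstream
                # (e.g. by a Concat), so they share a common consumer
                if g.find_common_consumers(fw.output[0], bw.output[0]):
                    pairs.append((fw, bw))
        return pairs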
