Commit ffe6792

Merge pull request #506 from lucienwang1009/multi_birnn
support multi birnn share the same input
2 parents ef0af82 + 444cc55 commit ffe6792
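
The tests added below exercise the pattern this merge enables: more than one bidirectional RNN consuming the same input tensor. As a minimal sketch (TF 1.x style; the `rnn` import is assumed to be `tensorflow.contrib.rnn`, matching the tests' use of `rnn.GRUCell`; variable names are illustrative), the newly supported graph shape looks like:

    import tensorflow as tf
    from tensorflow.contrib import rnn  # assumed import; the tests reference rnn.GRUCell

    # One placeholder feeds two independent bidirectional GRUs; distinct scopes
    # keep their weights separate, which is the case the rewriters now handle.
    x = tf.placeholder(tf.float32, [None, 3, 2], name="input_1")
    outputs_1, state_1 = tf.nn.bidirectional_dynamic_rnn(
        rnn.GRUCell(5), rnn.GRUCell(5), x, dtype=tf.float32, scope="bigru_1")
    outputs_2, state_2 = tf.nn.bidirectional_dynamic_rnn(
        rnn.GRUCell(10), rnn.GRUCell(10), x, dtype=tf.float32, scope="bigru_2")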

File tree

10 files changed: +643 −301 lines changed

tests/test_gru.py

Lines changed: 137 additions & 0 deletions
@@ -482,6 +482,143 @@ def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
                            graph_validator=lambda g: check_gru_count(g, 1))
 
+    def test_dynamic_bigru_unknown_batch_size(self):
+        units = 5
+        batch_size = 6
+        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
+        x_val = np.stack([x_val] * batch_size)
+
+        x = tf.placeholder(tf.float32, [None, 3, 2], name="input_1")
+
+        cell1 = rnn.GRUCell(units)
+        cell2 = rnn.GRUCell(units)
+        _, cell_state = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32,
+        )
+
+        _ = tf.identity(cell_state, name="cell_state")
+
+        feed_dict = {"input_1:0": x_val}
+        input_names_with_port = ["input_1:0"]
+        output_names_with_port = ["cell_state:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
+                           graph_validator=lambda g: check_gru_count(g, 1))
+
+    def test_dynamic_bigru_outputs_partially_consumed(self):
+        units = 5
+        batch_size = 6
+        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
+        x_val = np.stack([x_val] * batch_size)
+
+        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+
+        cell1 = rnn.GRUCell(units)
+        cell2 = rnn.GRUCell(units)
+        (output_fw, _), (_, state_bw) = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32)
+
+        _ = tf.identity(output_fw, name="output")
+        _ = tf.identity(state_bw, name="cell_state")
+
+        feed_dict = {"input_1:0": x_val}
+        input_names_with_port = ["input_1:0"]
+        output_names_with_port = ["output:0", "cell_state:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
+                           graph_validator=lambda g: check_gru_count(g, 1))
+
+    def test_dynamic_multi_bigru_with_same_input_hidden_size(self):
+        units = 5
+        batch_size = 10
+        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
+        x_val = np.stack([x_val] * batch_size)
+
+        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+
+        # bigru, no scope
+        cell1 = rnn.GRUCell(units)
+        cell2 = rnn.GRUCell(units)
+        outputs_1, cell_state_1 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32,
+            scope="bigru_1"
+        )
+
+        units = 10
+        cell1 = rnn.GRUCell(units)
+        cell2 = rnn.GRUCell(units)
+        outputs_2, cell_state_2 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32,
+            scope="bigru_2"
+        )
+
+        _ = tf.identity(outputs_1, name="output_1")
+        _ = tf.identity(cell_state_1, name="cell_state_1")
+        _ = tf.identity(outputs_2, name="output_2")
+        _ = tf.identity(cell_state_2, name="cell_state_2")
+
+        feed_dict = {"input_1:0": x_val}
+        input_names_with_port = ["input_1:0"]
+        output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
+                           graph_validator=lambda g: check_gru_count(g, 2))
+
+    def test_dynamic_multi_bigru_with_same_input_seq_len(self):
+        units = 5
+        batch_size = 10
+        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
+        x_val = np.stack([x_val] * batch_size)
+        seq_len_val = np.array([3], dtype=np.int32)
+
+        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+
+        y1 = tf.placeholder(tf.int32, seq_len_val.shape, name="input_2")
+        seq_len1 = tf.tile(y1, [batch_size])
+        cell1 = rnn.GRUCell(units)
+        cell2 = rnn.GRUCell(units)
+        outputs_1, cell_state_1 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            sequence_length=seq_len1,
+            dtype=tf.float32,
+            scope="bigru_1"
+        )
+
+        y2 = tf.placeholder(tf.int32, seq_len_val.shape, name="input_3")
+        seq_len2 = tf.tile(y2, [batch_size])
+        cell1 = rnn.GRUCell(units)
+        cell2 = rnn.GRUCell(units)
+        outputs_2, cell_state_2 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            sequence_length=seq_len2,
+            dtype=tf.float32,
+            scope="bigru_2"
+        )
+
+        _ = tf.identity(outputs_1, name="output_1")
+        _ = tf.identity(cell_state_1, name="cell_state_1")
+        _ = tf.identity(outputs_2, name="output_2")
+        _ = tf.identity(cell_state_2, name="cell_state_2")
+
+        feed_dict = {"input_1:0": x_val, "input_2:0": seq_len_val, "input_3:0": seq_len_val}
+        input_names_with_port = ["input_1:0", "input_2:0", "input_3:0"]
+        output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
+                           graph_validator=lambda g: check_gru_count(g, 2))
+
 
 if __name__ == '__main__':
     unittest_main()
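
check_gru_count is a helper defined elsewhere in the test suite; a hypothetical equivalent shows what the graph_validator callbacks assert (g.get_nodes() and n.type follow tf2onnx's Graph/Node API, but this exact body is illustrative):

    def check_gru_count(g, expected_count):
        # Hypothetical sketch: after conversion, the ONNX graph should contain
        # exactly `expected_count` GRU nodes, i.e. each forward/backward pair
        # was fused into one bidirectional GRU.
        return len([n for n in g.get_nodes() if n.type == "GRU"]) == expected_count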

tests/test_lstm.py

Lines changed: 151 additions & 0 deletions
@@ -554,6 +554,157 @@ def test_dynamic_bilstm_state_consumed_only(self, state_is_tuple=True):
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
                            graph_validator=lambda g: check_lstm_count(g, 1))
 
+    def test_dynamic_bilstm_outputs_partially_consumed(self, state_is_tuple=True):
+        units = 5
+        batch_size = 6
+        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
+        x_val = np.stack([x_val] * batch_size)
+
+        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+        initializer = init_ops.constant_initializer(0.5)
+
+        # bilstm, no scope
+        cell1 = rnn.LSTMCell(
+            units,
+            initializer=initializer,
+            state_is_tuple=state_is_tuple)  # state_is_tuple will impact Pack node (for cell_state)'s usage pattern
+        cell2 = rnn.LSTMCell(
+            units,
+            initializer=initializer,
+            state_is_tuple=state_is_tuple)
+        (output_fw, _), (_, state_bw) = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32)
+
+        _ = tf.identity(output_fw, name="output")
+        _ = tf.identity(state_bw, name="cell_state")
+
+        feed_dict = {"input_1:0": x_val}
+        input_names_with_port = ["input_1:0"]
+        output_names_with_port = ["output:0", "cell_state:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
+                           graph_validator=lambda g: check_lstm_count(g, 1))
+
+    def test_dynamic_bilstm_unknown_batch_size(self, state_is_tuple=True):
+        units = 5
+        batch_size = 6
+        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
+        x_val = np.stack([x_val] * batch_size)
+
+        x = tf.placeholder(tf.float32, [None, 3, 2], name="input_1")
+        initializer = init_ops.constant_initializer(0.5)
+
+        cell1 = rnn.LSTMCell(
+            units,
+            initializer=initializer,
+            state_is_tuple=state_is_tuple)
+        cell2 = rnn.LSTMCell(
+            units,
+            initializer=initializer,
+            state_is_tuple=state_is_tuple)
+        _, cell_state = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32,
+        )
+
+        _ = tf.identity(cell_state, name="cell_state")
+
+        feed_dict = {"input_1:0": x_val}
+        input_names_with_port = ["input_1:0"]
+        output_names_with_port = ["cell_state:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
+                           graph_validator=lambda g: check_lstm_count(g, 1))
+
+    def test_dynamic_multi_bilstm_with_same_input_hidden_size(self):
+        units = 5
+        batch_size = 10
+        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
+        x_val = np.stack([x_val] * batch_size)
+
+        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+
+        cell1 = rnn.LSTMCell(units)
+        cell2 = rnn.LSTMCell(units)
+        outputs_1, cell_state_1 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32,
+            scope="bilstm_1"
+        )
+
+        units = 10
+        cell1 = rnn.LSTMCell(units)
+        cell2 = rnn.LSTMCell(units)
+        outputs_2, cell_state_2 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32,
+            scope="bilstm_2"
+        )
+
+        _ = tf.identity(outputs_1, name="output_1")
+        _ = tf.identity(cell_state_1, name="cell_state_1")
+        _ = tf.identity(outputs_2, name="output_2")
+        _ = tf.identity(cell_state_2, name="cell_state_2")
+
+        feed_dict = {"input_1:0": x_val}
+        input_names_with_port = ["input_1:0"]
+        output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
+                           graph_validator=lambda g: check_lstm_count(g, 2))
+
+    def test_dynamic_multi_bilstm_with_same_input_seq_len(self):
+        units = 5
+        batch_size = 10
+        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
+        x_val = np.stack([x_val] * batch_size)
+        seq_len_val = np.array([3], dtype=np.int32)
+
+        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+
+        y1 = tf.placeholder(tf.int32, seq_len_val.shape, name="input_2")
+        seq_len1 = tf.tile(y1, [batch_size])
+        cell1 = rnn.LSTMCell(units)
+        cell2 = rnn.LSTMCell(units)
+        outputs_1, cell_state_1 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            sequence_length=seq_len1,
+            dtype=tf.float32,
+            scope="bilstm_1"
+        )
+
+        y2 = tf.placeholder(tf.int32, seq_len_val.shape, name="input_3")
+        seq_len2 = tf.tile(y2, [batch_size])
+        cell1 = rnn.LSTMCell(units)
+        cell2 = rnn.LSTMCell(units)
+        outputs_2, cell_state_2 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            sequence_length=seq_len2,
+            dtype=tf.float32,
+            scope="bilstm_2"
+        )
+
+        _ = tf.identity(outputs_1, name="output_1")
+        _ = tf.identity(cell_state_1, name="cell_state_1")
+        _ = tf.identity(outputs_2, name="output_2")
+        _ = tf.identity(cell_state_2, name="cell_state_2")
+
+        feed_dict = {"input_1:0": x_val, "input_2:0": seq_len_val, "input_3:0": seq_len_val}
+        input_names_with_port = ["input_1:0", "input_2:0", "input_3:0"]
+        output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
+                           graph_validator=lambda g: check_lstm_count(g, 2))
+
 
 if __name__ == '__main__':
     unittest_main()

tf2onnx/graph.py

Lines changed: 18 additions & 6 deletions
@@ -179,17 +179,29 @@ def get_attr(self, name, default=None):
         attr = self.attr.get(name, default)
         return attr
 
+    def get_attr_value(self, name, default=None):
+        attr = self.get_attr(name)
+        if attr:
+            return helper.get_attribute_value(attr)
+        return default
+
     def get_attr_int(self, name):
         """Get attribute value as int."""
-        attr = self.get_attr(name)
-        utils.make_sure(attr is not None, "attribute %s is None", name)
-        attr = attr.i
-        return attr
+        attr_int = self.get_attr_value(name)
+        utils.make_sure(
+            attr_int is not None and isinstance(attr_int, int),
+            "attribute %s is None", name
+        )
+        return attr_int
 
     def get_attr_str(self, name, encoding="utf-8"):
         """Get attribute value as string."""
-        attr = self.get_attr(name)
-        return attr.s.decode(encoding) if attr else None
+        attr_str = self.get_attr_value(name)
+        utils.make_sure(
+            attr_str is not None and isinstance(attr_str, bytes),
+            "attribute %s is None", name
+        )
+        return attr_str.decode(encoding)
 
     def set_attr(self, name, value):
         self.attr[name] = helper.make_attribute(name, value)
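
For context, a minimal self-contained sketch of the new accessor behavior. NodeSketch is a hypothetical stand-in for tf2onnx's Node that reproduces only the attribute logic above; onnx.helper.make_attribute and onnx.helper.get_attribute_value are the real onnx APIs the diff relies on:

    from onnx import helper

    class NodeSketch:
        """Hypothetical stand-in for tf2onnx's Node; mirrors the accessors above."""
        def __init__(self):
            self.attr = {}

        def set_attr(self, name, value):
            self.attr[name] = helper.make_attribute(name, value)

        def get_attr(self, name, default=None):
            return self.attr.get(name, default)

        def get_attr_value(self, name, default=None):
            attr = self.get_attr(name)
            if attr:
                return helper.get_attribute_value(attr)
            return default

    node = NodeSketch()
    node.set_attr("axis", 1)
    assert node.get_attr_value("axis") == 1        # ints decode to Python ints
    assert node.get_attr_value("missing", 7) == 7  # default when attribute is absent
    node.set_attr("mode", "constant")
    # strings are stored as bytes, hence the .decode(encoding) in get_attr_str
    assert node.get_attr_value("mode") == b"constant"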

tf2onnx/graph_matcher.py

Lines changed: 0 additions & 16 deletions
@@ -20,13 +20,10 @@
 from __future__ import unicode_literals
 
 import copy
-import logging
 
 import six
 
 
-logger = logging.getLogger(__name__)
-
 
 class OpTypePattern(object):
     """A tree pattern that matches TF expressions with certain op types."""

@@ -161,12 +158,6 @@ def _match_pattern(self, pattern, op, tensor):
 
         if pattern.op_type != '*':
             if op is None or op.type not in pattern.op_type.split('|'):
-                logger.debug(
-                    "mismatched type at %s: [%s, %s]",
-                    op.name if op else "None",
-                    pattern.op_type,
-                    op.type if op else "None"
-                )
                 return False
 
         self._match_result.add(pattern, op, tensor)

@@ -177,12 +168,6 @@ def _match_pattern(self, pattern, op, tensor):
             return True
 
         if not op or len(op.inputs) != len(pattern.inputs):
-            logger.debug(
-                "mismatched input number at %s: [%s, %s]",
-                op.name if op else "None",
-                len(pattern.inputs),
-                len(op.inputs)
-            )
             return False
 
         if self._allow_reorder:

@@ -219,7 +204,6 @@ def match_op(self, op):
         Returns a `MatchResult` if `op` matches the pattern; otherwise, returns
         None.
         """
-        logger.debug("match %s against the pattern", op.name)
         self._match_result = MatchResult()
         if not self._match_pattern(self._pattern, op, tensor=None):
             return None
