Commit 444cc55
pair bi-rnn by name hidden_size and seq_len
1 parent 03cf6cd commit 444cc55

12 files changed: 443 additions & 284 deletions
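
The renamed *_hidden_size tests and the new *_seq_len tests below cover two bidirectional RNNs that share the same input but differ either in hidden size or in their sequence_length tensors, which is what the commit title's pairing by hidden_size and seq_len appears to refer to. Each test passes a graph_validator to run_test_case to assert how many fused ONNX GRU/LSTM nodes remain after the forward and backward cells are paired. A minimal sketch of such a validator, assuming tf2onnx's Graph API; the real check_gru_count / check_lstm_count helpers live in the repository's test utilities and may differ:

# Hypothetical sketch of the graph_validator helpers used by the tests in this
# commit. The real check_gru_count / check_lstm_count live in the repository's
# test utilities; this version only assumes a tf2onnx Graph exposing get_nodes()
# and a node .type attribute.
def check_op_count(graph, op_type, expected):
    """Return True if `graph` contains exactly `expected` nodes of type `op_type`."""
    return len([n for n in graph.get_nodes() if n.type == op_type]) == expected

def check_gru_count(graph, expected):
    return check_op_count(graph, "GRU", expected)

def check_lstm_count(graph, expected):
    return check_op_count(graph, "LSTM", expected)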

tests/test_gru.py

Lines changed: 102 additions & 15 deletions
@@ -482,22 +482,67 @@ def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
                            graph_validator=lambda g: check_gru_count(g, 1))
 
-    def test_dynamic_multi_bigru_with_same_input(self):
+    def test_dynamic_bigru_unknown_batch_size(self):
         units = 5
-        batch_size = 1
+        batch_size = 6
+        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
+        x_val = np.stack([x_val] * batch_size)
+
+        x = tf.placeholder(tf.float32, [None, 3, 2], name="input_1")
+
+        cell1 = rnn.GRUCell(units)
+        cell2 = rnn.GRUCell(units)
+        _, cell_state = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32,
+        )
+
+        _ = tf.identity(cell_state, name="cell_state")
+
+        feed_dict = {"input_1:0": x_val}
+        input_names_with_port = ["input_1:0"]
+        output_names_with_port = ["cell_state:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
+                           graph_validator=lambda g: check_gru_count(g, 1))
+
+    def test_dynamic_bigru_outputs_partially_consumed(self):
+        units = 5
+        batch_size = 6
+        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
+        x_val = np.stack([x_val] * batch_size)
+
+        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+
+        cell1 = rnn.GRUCell(units)
+        cell2 = rnn.GRUCell(units)
+        (output_fw, _), (_, state_bw) = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32)
+
+        _ = tf.identity(output_fw, name="output")
+        _ = tf.identity(state_bw, name="cell_state")
+
+        feed_dict = {"input_1:0": x_val}
+        input_names_with_port = ["input_1:0"]
+        output_names_with_port = ["output:0", "cell_state:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
+                           graph_validator=lambda g: check_gru_count(g, 1))
+
+    def test_dynamic_multi_bigru_with_same_input_hidden_size(self):
+        units = 5
+        batch_size = 10
         x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
         x_val = np.stack([x_val] * batch_size)
 
         x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
-        initializer = init_ops.constant_initializer(0.5)
 
         # bigru, no scope
-        cell1 = rnn.GRUCell(
-            units,
-            kernel_initializer=initializer)
-        cell2 = rnn.GRUCell(
-            units,
-            kernel_initializer=initializer)
+        cell1 = rnn.GRUCell(units)
+        cell2 = rnn.GRUCell(units)
         outputs_1, cell_state_1 = tf.nn.bidirectional_dynamic_rnn(
             cell1,
             cell2,
@@ -507,12 +552,8 @@ def test_dynamic_multi_bigru_with_same_input(self):
         )
 
         units = 10
-        cell1 = rnn.GRUCell(
-            units,
-            kernel_initializer=initializer)
-        cell2 = rnn.GRUCell(
-            units,
-            kernel_initializer=initializer)
+        cell1 = rnn.GRUCell(units)
+        cell2 = rnn.GRUCell(units)
         outputs_2, cell_state_2 = tf.nn.bidirectional_dynamic_rnn(
             cell1,
             cell2,
@@ -532,6 +573,52 @@ def test_dynamic_multi_bigru_with_same_input(self):
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
                            graph_validator=lambda g: check_gru_count(g, 2))
 
+    def test_dynamic_multi_bigru_with_same_input_seq_len(self):
+        units = 5
+        batch_size = 10
+        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
+        x_val = np.stack([x_val] * batch_size)
+        seq_len_val = np.array([3], dtype=np.int32)
+
+        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+
+        y1 = tf.placeholder(tf.int32, seq_len_val.shape, name="input_2")
+        seq_len1 = tf.tile(y1, [batch_size])
+        cell1 = rnn.GRUCell(units)
+        cell2 = rnn.GRUCell(units)
+        outputs_1, cell_state_1 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            sequence_length=seq_len1,
+            dtype=tf.float32,
+            scope="bigru_1"
+        )
+
+        y2 = tf.placeholder(tf.int32, seq_len_val.shape, name="input_3")
+        seq_len2 = tf.tile(y2, [batch_size])
+        cell1 = rnn.GRUCell(units)
+        cell2 = rnn.GRUCell(units)
+        outputs_2, cell_state_2 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            sequence_length=seq_len2,
+            dtype=tf.float32,
+            scope="bigru_2"
+        )
+
+        _ = tf.identity(outputs_1, name="output_1")
+        _ = tf.identity(cell_state_1, name="cell_state_1")
+        _ = tf.identity(outputs_2, name="output_2")
+        _ = tf.identity(cell_state_2, name="cell_state_2")
+
+        feed_dict = {"input_1:0": x_val, "input_2:0": seq_len_val, "input_3:0": seq_len_val}
+        input_names_with_port = ["input_1:0", "input_2:0", "input_3:0"]
+        output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
+                           graph_validator=lambda g: check_gru_count(g, 2))
+
 
 if __name__ == '__main__':
     unittest_main()

tests/test_grublock.py

Lines changed: 0 additions & 41 deletions
@@ -447,47 +447,6 @@ def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
                            graph_validator=lambda g: check_gru_count(g, 1))
 
-    def test_dynamic_multi_bigru_with_same_input(self):
-        units = 5
-        batch_size = 1
-        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
-        x_val = np.stack([x_val] * batch_size)
-
-        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
-
-        # bigru, no scope
-        cell1 = rnn.GRUBlockCell(units)
-        cell2 = rnn.GRUBlockCell(units)
-        outputs_1, cell_state_1 = tf.nn.bidirectional_dynamic_rnn(
-            cell1,
-            cell2,
-            x,
-            dtype=tf.float32,
-            scope="bigru_1"
-        )
-
-        units = 10
-        cell1 = rnn.GRUBlockCell(units)
-        cell2 = rnn.GRUBlockCell(units)
-        outputs_2, cell_state_2 = tf.nn.bidirectional_dynamic_rnn(
-            cell1,
-            cell2,
-            x,
-            dtype=tf.float32,
-            scope="bigru_2"
-        )
-
-        _ = tf.identity(outputs_1, name="output_1")
-        _ = tf.identity(cell_state_1, name="cell_state_1")
-        _ = tf.identity(outputs_2, name="output_2")
-        _ = tf.identity(cell_state_2, name="cell_state_2")
-
-        feed_dict = {"input_1:0": x_val}
-        input_names_with_port = ["input_1:0"]
-        output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
-        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
-                           graph_validator=lambda g: check_gru_count(g, 2))
-
 
 if __name__ == '__main__':
     unittest_main()

tests/test_lstm.py

Lines changed: 108 additions & 20 deletions
@@ -554,50 +554,92 @@ def test_dynamic_bilstm_state_consumed_only(self, state_is_tuple=True):
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
                            graph_validator=lambda g: check_lstm_count(g, 1))
 
-    def test_dynamic_multi_bilstm_with_same_input_state_is_tuple(self):
-        self.internal_test_dynamic_multi_bilstm_with_same_input(True)
-
-    def test_dynamic_multi_bilstm_with_same_input_state_is_list(self):
-        self.internal_test_dynamic_multi_bilstm_with_same_input(False)
-
-    def internal_test_dynamic_multi_bilstm_with_same_input(self, state_is_tuple):
+    def test_dynamic_bilstm_outputs_partially_consumed(self, state_is_tuple=True):
         units = 5
-        batch_size = 1
+        batch_size = 6
         x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
         x_val = np.stack([x_val] * batch_size)
 
         x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
         initializer = init_ops.constant_initializer(0.5)
 
+        # bilstm, no scope
         cell1 = rnn.LSTMCell(
             units,
             initializer=initializer,
-            state_is_tuple=state_is_tuple
-        )
+            state_is_tuple=state_is_tuple)  # state_is_tuple will impact Pack node (for cell_state)'s usage pattern
         cell2 = rnn.LSTMCell(
             units,
             initializer=initializer,
-            state_is_tuple=state_is_tuple
-        )
-        outputs_1, cell_state_1 = tf.nn.bidirectional_dynamic_rnn(
+            state_is_tuple=state_is_tuple)
+        (output_fw, _), (_, state_bw) = tf.nn.bidirectional_dynamic_rnn(
             cell1,
             cell2,
             x,
-            dtype=tf.float32,
-            scope="bilstm_1"
-        )
+            dtype=tf.float32)
+
+        _ = tf.identity(output_fw, name="output")
+        _ = tf.identity(state_bw, name="cell_state")
+
+        feed_dict = {"input_1:0": x_val}
+        input_names_with_port = ["input_1:0"]
+        output_names_with_port = ["output:0", "cell_state:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
+                           graph_validator=lambda g: check_lstm_count(g, 1))
+
+    def test_dynamic_bilstm_unknown_batch_size(self, state_is_tuple=True):
+        units = 5
+        batch_size = 6
+        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
+        x_val = np.stack([x_val] * batch_size)
+
+        x = tf.placeholder(tf.float32, [None, 3, 2], name="input_1")
+        initializer = init_ops.constant_initializer(0.5)
 
-        units = 10
         cell1 = rnn.LSTMCell(
             units,
             initializer=initializer,
-            state_is_tuple=state_is_tuple
-        )
+            state_is_tuple=state_is_tuple)
         cell2 = rnn.LSTMCell(
             units,
             initializer=initializer,
-            state_is_tuple=state_is_tuple
+            state_is_tuple=state_is_tuple)
+        _, cell_state = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32,
+        )
+
+        _ = tf.identity(cell_state, name="cell_state")
+
+        feed_dict = {"input_1:0": x_val}
+        input_names_with_port = ["input_1:0"]
+        output_names_with_port = ["cell_state:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
+                           graph_validator=lambda g: check_lstm_count(g, 1))
+
+    def test_dynamic_multi_bilstm_with_same_input_hidden_size(self):
+        units = 5
+        batch_size = 10
+        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
+        x_val = np.stack([x_val] * batch_size)
+
+        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+
+        cell1 = rnn.LSTMCell(units)
+        cell2 = rnn.LSTMCell(units)
+        outputs_1, cell_state_1 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            dtype=tf.float32,
+            scope="bilstm_1"
         )
+
+        units = 10
+        cell1 = rnn.LSTMCell(units)
+        cell2 = rnn.LSTMCell(units)
         outputs_2, cell_state_2 = tf.nn.bidirectional_dynamic_rnn(
             cell1,
             cell2,
@@ -617,6 +659,52 @@ def internal_test_dynamic_multi_bilstm_with_same_input(self, state_is_tuple):
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
                            graph_validator=lambda g: check_lstm_count(g, 2))
 
+    def test_dynamic_multi_bilstm_with_same_input_seq_len(self):
+        units = 5
+        batch_size = 10
+        x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
+        x_val = np.stack([x_val] * batch_size)
+        seq_len_val = np.array([3], dtype=np.int32)
+
+        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
+
+        y1 = tf.placeholder(tf.int32, seq_len_val.shape, name="input_2")
+        seq_len1 = tf.tile(y1, [batch_size])
+        cell1 = rnn.LSTMCell(units)
+        cell2 = rnn.LSTMCell(units)
+        outputs_1, cell_state_1 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            sequence_length=seq_len1,
+            dtype=tf.float32,
+            scope="bilstm_1"
+        )
+
+        y2 = tf.placeholder(tf.int32, seq_len_val.shape, name="input_3")
+        seq_len2 = tf.tile(y2, [batch_size])
+        cell1 = rnn.LSTMCell(units)
+        cell2 = rnn.LSTMCell(units)
+        outputs_2, cell_state_2 = tf.nn.bidirectional_dynamic_rnn(
+            cell1,
+            cell2,
+            x,
+            sequence_length=seq_len2,
+            dtype=tf.float32,
+            scope="bilstm_2"
+        )
+
+        _ = tf.identity(outputs_1, name="output_1")
+        _ = tf.identity(cell_state_1, name="cell_state_1")
+        _ = tf.identity(outputs_2, name="output_2")
+        _ = tf.identity(cell_state_2, name="cell_state_2")
+
+        feed_dict = {"input_1:0": x_val, "input_2:0": seq_len_val, "input_3:0": seq_len_val}
+        input_names_with_port = ["input_1:0", "input_2:0", "input_3:0"]
+        output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
+                           graph_validator=lambda g: check_lstm_count(g, 2))
+
 
 if __name__ == '__main__':
     unittest_main()
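
For context, a rough sketch of how a frozen graph containing one of the bidirectional RNNs above might be converted outside the test harness. This is not part of the commit, and the tf2onnx Python API assumed here (process_tf_graph, Graph.make_model) may differ by version:

# Hypothetical conversion sketch, not part of this commit: load a frozen TF1
# GraphDef that contains one of the bi-RNN graphs above and convert it to ONNX.
# Assumes tf2onnx's process_tf_graph / Graph.make_model API; names may differ.
import tensorflow as tf
from tf2onnx import tfonnx

graph_def = tf.GraphDef()
with open("bilstm_frozen.pb", "rb") as f:  # hypothetical frozen graph file
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as tf_graph:
    tf.import_graph_def(graph_def, name="")
    # The bi-RNN rewriter runs inside process_tf_graph; after pairing the
    # forward and backward cells it should leave a single bidirectional ONNX
    # LSTM node, which is what check_lstm_count(g, 1) asserts in the tests.
    onnx_graph = tfonnx.process_tf_graph(
        tf_graph,
        input_names=["input_1:0"],
        output_names=["output:0", "cell_state:0"])

model_proto = onnx_graph.make_model("bidirectional LSTM test graph")
with open("bilstm.onnx", "wb") as f:
    f.write(model_proto.SerializeToString())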
