@@ -554,50 +554,92 @@ def test_dynamic_bilstm_state_consumed_only(self, state_is_tuple=True):
554
554
self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 ,
555
555
graph_validator = lambda g : check_lstm_count (g , 1 ))
556
556
557
- def test_dynamic_multi_bilstm_with_same_input_state_is_tuple (self ):
558
- self .internal_test_dynamic_multi_bilstm_with_same_input (True )
559
-
560
- def test_dynamic_multi_bilstm_with_same_input_state_is_list (self ):
561
- self .internal_test_dynamic_multi_bilstm_with_same_input (False )
562
-
563
- def internal_test_dynamic_multi_bilstm_with_same_input (self , state_is_tuple ):
557
+ def test_dynamic_bilstm_outputs_partially_consumed (self , state_is_tuple = True ):
564
558
units = 5
565
- batch_size = 1
559
+ batch_size = 6
566
560
x_val = np .array ([[1. , 1. ], [2. , 2. ], [3. , 3. ]], dtype = np .float32 )
567
561
x_val = np .stack ([x_val ] * batch_size )
568
562
569
563
x = tf .placeholder (tf .float32 , x_val .shape , name = "input_1" )
570
564
initializer = init_ops .constant_initializer (0.5 )
571
565
566
+ # bilstm, no scope
572
567
cell1 = rnn .LSTMCell (
573
568
units ,
574
569
initializer = initializer ,
575
- state_is_tuple = state_is_tuple
576
- )
570
+ state_is_tuple = state_is_tuple ) # state_is_tuple will impact Pack node (for cell_state)'s usage pattern
577
571
cell2 = rnn .LSTMCell (
578
572
units ,
579
573
initializer = initializer ,
580
- state_is_tuple = state_is_tuple
581
- )
582
- outputs_1 , cell_state_1 = tf .nn .bidirectional_dynamic_rnn (
574
+ state_is_tuple = state_is_tuple )
575
+ (output_fw , _ ), (_ , state_bw ) = tf .nn .bidirectional_dynamic_rnn (
583
576
cell1 ,
584
577
cell2 ,
585
578
x ,
586
- dtype = tf .float32 ,
587
- scope = "bilstm_1"
588
- )
579
+ dtype = tf .float32 )
580
+
581
+ _ = tf .identity (output_fw , name = "output" )
582
+ _ = tf .identity (state_bw , name = "cell_state" )
583
+
584
+ feed_dict = {"input_1:0" : x_val }
585
+ input_names_with_port = ["input_1:0" ]
586
+ output_names_with_port = ["output:0" , "cell_state:0" ]
587
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 ,
588
+ graph_validator = lambda g : check_lstm_count (g , 1 ))
589
+
590
+ def test_dynamic_bilstm_unknown_batch_size (self , state_is_tuple = True ):
591
+ units = 5
592
+ batch_size = 6
593
+ x_val = np .array ([[1. , 1. ], [2. , 2. ], [3. , 3. ]], dtype = np .float32 )
594
+ x_val = np .stack ([x_val ] * batch_size )
595
+
596
+ x = tf .placeholder (tf .float32 , [None , 3 , 2 ], name = "input_1" )
597
+ initializer = init_ops .constant_initializer (0.5 )
589
598
590
- units = 10
591
599
cell1 = rnn .LSTMCell (
592
600
units ,
593
601
initializer = initializer ,
594
- state_is_tuple = state_is_tuple
595
- )
602
+ state_is_tuple = state_is_tuple )
596
603
cell2 = rnn .LSTMCell (
597
604
units ,
598
605
initializer = initializer ,
599
- state_is_tuple = state_is_tuple
606
+ state_is_tuple = state_is_tuple )
607
+ _ , cell_state = tf .nn .bidirectional_dynamic_rnn (
608
+ cell1 ,
609
+ cell2 ,
610
+ x ,
611
+ dtype = tf .float32 ,
612
+ )
613
+
614
+ _ = tf .identity (cell_state , name = "cell_state" )
615
+
616
+ feed_dict = {"input_1:0" : x_val }
617
+ input_names_with_port = ["input_1:0" ]
618
+ output_names_with_port = ["cell_state:0" ]
619
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 ,
620
+ graph_validator = lambda g : check_lstm_count (g , 1 ))
621
+
622
+ def test_dynamic_multi_bilstm_with_same_input_hidden_size (self ):
623
+ units = 5
624
+ batch_size = 10
625
+ x_val = np .array ([[1. , 1. ], [2. , 2. ], [3. , 3. ]], dtype = np .float32 )
626
+ x_val = np .stack ([x_val ] * batch_size )
627
+
628
+ x = tf .placeholder (tf .float32 , x_val .shape , name = "input_1" )
629
+
630
+ cell1 = rnn .LSTMCell (units )
631
+ cell2 = rnn .LSTMCell (units )
632
+ outputs_1 , cell_state_1 = tf .nn .bidirectional_dynamic_rnn (
633
+ cell1 ,
634
+ cell2 ,
635
+ x ,
636
+ dtype = tf .float32 ,
637
+ scope = "bilstm_1"
600
638
)
639
+
640
+ units = 10
641
+ cell1 = rnn .LSTMCell (units )
642
+ cell2 = rnn .LSTMCell (units )
601
643
outputs_2 , cell_state_2 = tf .nn .bidirectional_dynamic_rnn (
602
644
cell1 ,
603
645
cell2 ,
@@ -617,6 +659,52 @@ def internal_test_dynamic_multi_bilstm_with_same_input(self, state_is_tuple):
617
659
self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 ,
618
660
graph_validator = lambda g : check_lstm_count (g , 2 ))
619
661
662
+ def test_dynamic_multi_bilstm_with_same_input_seq_len (self ):
663
+ units = 5
664
+ batch_size = 10
665
+ x_val = np .array ([[1. , 1. ], [2. , 2. ], [3. , 3. ]], dtype = np .float32 )
666
+ x_val = np .stack ([x_val ] * batch_size )
667
+ seq_len_val = np .array ([3 ], dtype = np .int32 )
668
+
669
+ x = tf .placeholder (tf .float32 , x_val .shape , name = "input_1" )
670
+
671
+ y1 = tf .placeholder (tf .int32 , seq_len_val .shape , name = "input_2" )
672
+ seq_len1 = tf .tile (y1 , [batch_size ])
673
+ cell1 = rnn .LSTMCell (units )
674
+ cell2 = rnn .LSTMCell (units )
675
+ outputs_1 , cell_state_1 = tf .nn .bidirectional_dynamic_rnn (
676
+ cell1 ,
677
+ cell2 ,
678
+ x ,
679
+ sequence_length = seq_len1 ,
680
+ dtype = tf .float32 ,
681
+ scope = "bilstm_1"
682
+ )
683
+
684
+ y2 = tf .placeholder (tf .int32 , seq_len_val .shape , name = "input_3" )
685
+ seq_len2 = tf .tile (y2 , [batch_size ])
686
+ cell1 = rnn .LSTMCell (units )
687
+ cell2 = rnn .LSTMCell (units )
688
+ outputs_2 , cell_state_2 = tf .nn .bidirectional_dynamic_rnn (
689
+ cell1 ,
690
+ cell2 ,
691
+ x ,
692
+ sequence_length = seq_len2 ,
693
+ dtype = tf .float32 ,
694
+ scope = "bilstm_2"
695
+ )
696
+
697
+ _ = tf .identity (outputs_1 , name = "output_1" )
698
+ _ = tf .identity (cell_state_1 , name = "cell_state_1" )
699
+ _ = tf .identity (outputs_2 , name = "output_2" )
700
+ _ = tf .identity (cell_state_2 , name = "cell_state_2" )
701
+
702
+ feed_dict = {"input_1:0" : x_val , "input_2:0" : seq_len_val , "input_3:0" : seq_len_val }
703
+ input_names_with_port = ["input_1:0" , "input_2:0" , "input_3:0" ]
704
+ output_names_with_port = ["output_1:0" , "cell_state_1:0" , "output_2:0" , "cell_state_2:0" ]
705
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 ,
706
+ graph_validator = lambda g : check_lstm_count (g , 2 ))
707
+
620
708
621
709
if __name__ == '__main__' :
622
710
unittest_main ()
0 commit comments