12
12
from tensorflow .contrib import rnn
13
13
from tensorflow .python .ops import init_ops
14
14
from backend_test_base import Tf2OnnxBackendTestBase
15
- from common import check_opset_min_version , unittest_main
15
+ from common import unittest_main
16
16
17
17
18
18
# pylint: disable=missing-docstring
@@ -23,8 +23,8 @@ def test_dynamic_decode_maximum_iterations(self):
23
23
num_units = 4
24
24
vocab_size = 5
25
25
embedding_size = 3
26
- GO_SYMBOL = 0
27
- END_SYMBOL = 1
26
+ go_token = 0
27
+ end_token = 1
28
28
29
29
embedding = tf .constant (np .ones ([vocab_size , embedding_size ], dtype = np .float32 ))
30
30
state_val = np .reshape ([np .ones ([num_units ], dtype = np .float32 ) * i for i in range (batch_size )],
@@ -38,8 +38,8 @@ def test_dynamic_decode_maximum_iterations(self):
38
38
39
39
helper = tf .contrib .seq2seq .GreedyEmbeddingHelper (
40
40
embedding = embedding ,
41
- start_tokens = tf .tile ([GO_SYMBOL ], [batch_size ]),
42
- end_token = END_SYMBOL )
41
+ start_tokens = tf .tile ([go_token ], [batch_size ]),
42
+ end_token = end_token )
43
43
44
44
output_layer = tf .layers .Dense (vocab_size , kernel_initializer = initializer )
45
45
decoder = tf .contrib .seq2seq .BasicDecoder (
@@ -65,6 +65,88 @@ def test_dynamic_decode_maximum_iterations(self):
65
65
66
66
self .run_test_case ({}, [], output_names_with_port , atol = 1e-06 , rtol = 1e-6 )
67
67
68
+ def test_dynamic_decode_normal_stop (self ):
69
+ batch_size = 2
70
+ num_units = 4
71
+ vocab_size = 5
72
+ embedding_size = 3
73
+ go_token = 0
74
+ end_token = 1
75
+
76
+ embedding = tf .constant (np .ones ([vocab_size , embedding_size ], dtype = np .float32 ))
77
+ state_val = np .reshape ([np .ones ([num_units ], dtype = np .float32 ) * i for i in range (batch_size )],
78
+ [batch_size , num_units ])
79
+ encoder_state = tf .nn .rnn_cell .LSTMStateTuple (state_val , state_val )
80
+
81
+ cell_initializer = init_ops .constant_initializer (
82
+ np .array ([[- 0.9592235 , 0.42451382 , 0.7437744 , - 0.54485345 , - 0.80763197 ,
83
+ 0.19663906 , - 0.22738314 , 0.7762785 , 0.7464578 , 0.27227187 ,
84
+ 0.7661047 , 0.3596425 , - 0.8528242 , - 0.89316916 , - 0.48946142 ,
85
+ 0.87882376 ],
86
+ [0.86586094 , - 0.75018406 , 0.25992537 , - 0.69368935 , 0.2515502 ,
87
+ - 0.26379275 , 0.8954313 , 0.5759742 , - 0.7753072 , - 0.4388857 ,
88
+ 0.95751476 , - 0.82085776 , - 0.9467752 , - 0.37055635 , - 0.18570113 ,
89
+ - 0.86504984 ],
90
+ [0.02305841 , 0.3850248 , 0.893692 , - 0.6866486 , - 0.83703446 ,
91
+ - 0.9828961 , 0.3989377 , - 0.59993076 , 0.5330808 , 0.6916566 ,
92
+ 0.98468065 , - 0.6047034 , 0.10823512 , 0.34599304 , - 0.7834821 ,
93
+ - 0.7852347 ],
94
+ [0.81643987 , 0.31507468 , - 0.51369476 , - 0.12273741 , 0.9701307 ,
95
+ - 0.79669356 , - 0.34496522 , - 0.88750815 , - 0.17995334 , 0.34707904 ,
96
+ - 0.09201193 , 0.5363934 , - 0.87229705 , - 0.5073328 , - 0.95894027 ,
97
+ 0.5481839 ],
98
+ [- 0.84093595 , - 0.2341497 , - 0.86047816 , 0.43370056 , - 0.39073753 ,
99
+ 0.37730122 , 0.48026466 , 0.3004985 , - 0.60727096 , 0.9043884 ,
100
+ - 0.37619448 , 0.22490788 , - 0.03739262 , 0.61672115 , 0.478899 ,
101
+ - 0.40780973 ],
102
+ [0.31202435 , - 0.22045255 , - 0.6087918 , 0.95115066 , 0.00199413 ,
103
+ - 0.688287 , - 0.1103518 , 0.4169519 , 0.7913246 , - 0.9844644 ,
104
+ - 0.6193857 , 0.38659644 , - 0.4726901 , - 0.44781208 , - 0.5174744 ,
105
+ - 0.605911 ],
106
+ [0.66771054 , 0.34912825 , 0.22297978 , - 0.4990945 , 0.24057317 ,
107
+ - 0.5540829 , 0.92277217 , 0.74939895 , - 0.35278273 , - 0.21587133 ,
108
+ - 0.28613377 , - 0.8794241 , - 0.40119147 , 0.67175174 , - 0.22741508 ,
109
+ 0.37898326 ]], dtype = np .float32 ))
110
+ dense_initializer = init_ops .constant_initializer (
111
+ np .array ([[0.56177187 , - 0.6233454 , 0.73997784 , 0.35032558 , 0.6479795 ],
112
+ [0.6831174 , - 0.34233975 , 0.39330363 , 0.45177555 , - 0.49649096 ],
113
+ [- 0.98890066 , 0.6175642 , 0.09800482 , - 0.6721206 , 0.48805737 ],
114
+ [0.19671416 , 0.2623148 , 0.742548 , 0.13555217 , 0.56009054 ]], dtype = np .float32 ))
115
+
116
+ cell = rnn .LSTMCell (
117
+ num_units = num_units ,
118
+ initializer = cell_initializer ,
119
+ state_is_tuple = True )
120
+
121
+ helper = tf .contrib .seq2seq .GreedyEmbeddingHelper (
122
+ embedding = embedding ,
123
+ start_tokens = tf .tile ([go_token ], [batch_size ]),
124
+ end_token = end_token )
125
+
126
+ output_layer = tf .layers .Dense (vocab_size , kernel_initializer = dense_initializer )
127
+ decoder = tf .contrib .seq2seq .BasicDecoder (
128
+ cell = cell ,
129
+ helper = helper ,
130
+ initial_state = encoder_state ,
131
+ output_layer = output_layer )
132
+
133
+ outputs , state , sequence_lengths = tf .contrib .seq2seq .dynamic_decode (
134
+ decoder = decoder ,
135
+ maximum_iterations = 6 )
136
+
137
+ _ = tf .identity (outputs .rnn_output , name = "rnn_output" )
138
+ _ = tf .identity (outputs .sample_id , name = "sample_id" )
139
+ _ = tf .identity (state , name = "state" )
140
+ _ = tf .identity (sequence_lengths , name = "sequence_lengths" )
141
+
142
+ output_names_with_port = [
143
+ "rnn_output:0" ,
144
+ # "sample_id:0", # incomplete type support for Transpose on onnxruntime 0.2.1
145
+ "state:0" ,
146
+ ]
147
+
148
+ self .run_test_case ({}, [], output_names_with_port , atol = 1e-06 , rtol = 1e-6 )
149
+
68
150
69
151
if __name__ == '__main__' :
70
152
unittest_main ()
0 commit comments