Skip to content

Commit 957f670

Browse files
committed
add test_dynamic_decode_normal_stop
1 parent fba20cf commit 957f670

File tree

1 file changed

+87
-5
lines changed

1 file changed

+87
-5
lines changed

tests/test_seq2seq.py

Lines changed: 87 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from tensorflow.contrib import rnn
1313
from tensorflow.python.ops import init_ops
1414
from backend_test_base import Tf2OnnxBackendTestBase
15-
from common import check_opset_min_version, unittest_main
15+
from common import unittest_main
1616

1717

1818
# pylint: disable=missing-docstring
@@ -23,8 +23,8 @@ def test_dynamic_decode_maximum_iterations(self):
2323
num_units = 4
2424
vocab_size = 5
2525
embedding_size = 3
26-
GO_SYMBOL = 0
27-
END_SYMBOL = 1
26+
go_token = 0
27+
end_token = 1
2828

2929
embedding = tf.constant(np.ones([vocab_size, embedding_size], dtype=np.float32))
3030
state_val = np.reshape([np.ones([num_units], dtype=np.float32) * i for i in range(batch_size)],
@@ -38,8 +38,8 @@ def test_dynamic_decode_maximum_iterations(self):
3838

3939
helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
4040
embedding=embedding,
41-
start_tokens=tf.tile([GO_SYMBOL], [batch_size]),
42-
end_token=END_SYMBOL)
41+
start_tokens=tf.tile([go_token], [batch_size]),
42+
end_token=end_token)
4343

4444
output_layer = tf.layers.Dense(vocab_size, kernel_initializer=initializer)
4545
decoder = tf.contrib.seq2seq.BasicDecoder(
@@ -65,6 +65,88 @@ def test_dynamic_decode_maximum_iterations(self):
6565

6666
self.run_test_case({}, [], output_names_with_port, atol=1e-06, rtol=1e-6)
6767

68+
def test_dynamic_decode_normal_stop(self):
69+
batch_size = 2
70+
num_units = 4
71+
vocab_size = 5
72+
embedding_size = 3
73+
go_token = 0
74+
end_token = 1
75+
76+
embedding = tf.constant(np.ones([vocab_size, embedding_size], dtype=np.float32))
77+
state_val = np.reshape([np.ones([num_units], dtype=np.float32) * i for i in range(batch_size)],
78+
[batch_size, num_units])
79+
encoder_state = tf.nn.rnn_cell.LSTMStateTuple(state_val, state_val)
80+
81+
cell_initializer = init_ops.constant_initializer(
82+
np.array([[-0.9592235, 0.42451382, 0.7437744, -0.54485345, -0.80763197,
83+
0.19663906, -0.22738314, 0.7762785, 0.7464578, 0.27227187,
84+
0.7661047, 0.3596425, -0.8528242, -0.89316916, -0.48946142,
85+
0.87882376],
86+
[0.86586094, -0.75018406, 0.25992537, -0.69368935, 0.2515502,
87+
-0.26379275, 0.8954313, 0.5759742, -0.7753072, -0.4388857,
88+
0.95751476, -0.82085776, -0.9467752, -0.37055635, -0.18570113,
89+
-0.86504984],
90+
[0.02305841, 0.3850248, 0.893692, -0.6866486, -0.83703446,
91+
-0.9828961, 0.3989377, -0.59993076, 0.5330808, 0.6916566,
92+
0.98468065, -0.6047034, 0.10823512, 0.34599304, -0.7834821,
93+
-0.7852347],
94+
[0.81643987, 0.31507468, -0.51369476, -0.12273741, 0.9701307,
95+
-0.79669356, -0.34496522, -0.88750815, -0.17995334, 0.34707904,
96+
-0.09201193, 0.5363934, -0.87229705, -0.5073328, -0.95894027,
97+
0.5481839],
98+
[-0.84093595, -0.2341497, -0.86047816, 0.43370056, -0.39073753,
99+
0.37730122, 0.48026466, 0.3004985, -0.60727096, 0.9043884,
100+
-0.37619448, 0.22490788, -0.03739262, 0.61672115, 0.478899,
101+
-0.40780973],
102+
[0.31202435, -0.22045255, -0.6087918, 0.95115066, 0.00199413,
103+
-0.688287, -0.1103518, 0.4169519, 0.7913246, -0.9844644,
104+
-0.6193857, 0.38659644, -0.4726901, -0.44781208, -0.5174744,
105+
-0.605911],
106+
[0.66771054, 0.34912825, 0.22297978, -0.4990945, 0.24057317,
107+
-0.5540829, 0.92277217, 0.74939895, -0.35278273, -0.21587133,
108+
-0.28613377, -0.8794241, -0.40119147, 0.67175174, -0.22741508,
109+
0.37898326]], dtype=np.float32))
110+
dense_initializer = init_ops.constant_initializer(
111+
np.array([[0.56177187, -0.6233454, 0.73997784, 0.35032558, 0.6479795],
112+
[0.6831174, -0.34233975, 0.39330363, 0.45177555, -0.49649096],
113+
[-0.98890066, 0.6175642, 0.09800482, -0.6721206, 0.48805737],
114+
[0.19671416, 0.2623148, 0.742548, 0.13555217, 0.56009054]], dtype=np.float32))
115+
116+
cell = rnn.LSTMCell(
117+
num_units=num_units,
118+
initializer=cell_initializer,
119+
state_is_tuple=True)
120+
121+
helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
122+
embedding=embedding,
123+
start_tokens=tf.tile([go_token], [batch_size]),
124+
end_token=end_token)
125+
126+
output_layer = tf.layers.Dense(vocab_size, kernel_initializer=dense_initializer)
127+
decoder = tf.contrib.seq2seq.BasicDecoder(
128+
cell=cell,
129+
helper=helper,
130+
initial_state=encoder_state,
131+
output_layer=output_layer)
132+
133+
outputs, state, sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
134+
decoder=decoder,
135+
maximum_iterations=6)
136+
137+
_ = tf.identity(outputs.rnn_output, name="rnn_output")
138+
_ = tf.identity(outputs.sample_id, name="sample_id")
139+
_ = tf.identity(state, name="state")
140+
_ = tf.identity(sequence_lengths, name="sequence_lengths")
141+
142+
output_names_with_port = [
143+
"rnn_output:0",
144+
# "sample_id:0", # incomplete type support for Transpose on onnxruntime 0.2.1
145+
"state:0",
146+
]
147+
148+
self.run_test_case({}, [], output_names_with_port, atol=1e-06, rtol=1e-6)
149+
68150

69151
if __name__ == '__main__':
70152
unittest_main()

0 commit comments

Comments
 (0)