Skip to content

Commit be51386

Browse files
committed
add seq2seq test
1 parent 081e295 commit be51386

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed

tests/test_seq2seq.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
""" Unit Tests for tf.contrib.seq2seq """
5+
6+
from __future__ import absolute_import
7+
from __future__ import division
8+
from __future__ import print_function
9+
10+
import numpy as np
11+
import tensorflow as tf
12+
from tensorflow.contrib import rnn
13+
from tensorflow.python.ops import init_ops
14+
from backend_test_base import Tf2OnnxBackendTestBase
15+
from common import check_opset_min_version, unittest_main
16+
17+
18+
# pylint: disable=missing-docstring
19+
20+
class Seq2SeqTests(Tf2OnnxBackendTestBase):
21+
def test_dynamic_decode_maximum_iterations(self):
22+
batch_size = 2
23+
num_units = 4
24+
vocab_size = 5
25+
embedding_size = 3
26+
GO_SYMBOL = 0
27+
END_SYMBOL = 1
28+
29+
embedding = tf.constant(np.ones([vocab_size, embedding_size], dtype=np.float32))
30+
state_val = np.reshape([np.ones([num_units], dtype=np.float32) * i for i in range(batch_size)],
31+
[batch_size, num_units])
32+
encoder_state = tf.nn.rnn_cell.LSTMStateTuple(state_val, state_val)
33+
initializer = init_ops.constant_initializer(0.5)
34+
cell = rnn.LSTMCell(
35+
num_units=num_units,
36+
initializer=initializer,
37+
state_is_tuple=True)
38+
39+
helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
40+
embedding=embedding,
41+
start_tokens=tf.tile([GO_SYMBOL], [batch_size]),
42+
end_token=END_SYMBOL)
43+
44+
output_layer = tf.layers.Dense(vocab_size, kernel_initializer=initializer)
45+
decoder = tf.contrib.seq2seq.BasicDecoder(
46+
cell=cell,
47+
helper=helper,
48+
initial_state=encoder_state,
49+
output_layer=output_layer)
50+
51+
outputs, state, sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
52+
decoder=decoder,
53+
maximum_iterations=6)
54+
55+
_ = tf.identity(outputs.rnn_output, name="rnn_output")
56+
_ = tf.identity(outputs.sample_id, name="sample_id")
57+
_ = tf.identity(state, name="state")
58+
_ = tf.identity(sequence_lengths, name="sequence_lengths")
59+
60+
output_names_with_port = [
61+
"rnn_output:0",
62+
# "sample_id:0", # incomplete type support for Transpose on onnxruntime 0.2.1
63+
"state:0",
64+
]
65+
66+
self.run_test_case({}, [], output_names_with_port, atol=1e-06, rtol=1e-6)
67+
68+
69+
if __name__ == '__main__':
70+
unittest_main()

0 commit comments

Comments
 (0)