Skip to content

Commit b9f3175

Browse files
authored
Merge pull request #360 from nbcsm/s2s
add seq2seq test
2 parents 3c9e997 + 957f670 commit b9f3175

File tree

2 files changed

+152
-5
lines changed

2 files changed

+152
-5
lines changed

tests/test_seq2seq.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
""" Unit Tests for tf.contrib.seq2seq """
5+
6+
from __future__ import absolute_import
7+
from __future__ import division
8+
from __future__ import print_function
9+
10+
import numpy as np
11+
import tensorflow as tf
12+
from tensorflow.contrib import rnn
13+
from tensorflow.python.ops import init_ops
14+
from backend_test_base import Tf2OnnxBackendTestBase
15+
from common import unittest_main
16+
17+
18+
# pylint: disable=missing-docstring
19+
20+
class Seq2SeqTests(Tf2OnnxBackendTestBase):
21+
def test_dynamic_decode_maximum_iterations(self):
22+
batch_size = 2
23+
num_units = 4
24+
vocab_size = 5
25+
embedding_size = 3
26+
go_token = 0
27+
end_token = 1
28+
29+
embedding = tf.constant(np.ones([vocab_size, embedding_size], dtype=np.float32))
30+
state_val = np.reshape([np.ones([num_units], dtype=np.float32) * i for i in range(batch_size)],
31+
[batch_size, num_units])
32+
encoder_state = tf.nn.rnn_cell.LSTMStateTuple(state_val, state_val)
33+
initializer = init_ops.constant_initializer(0.5)
34+
cell = rnn.LSTMCell(
35+
num_units=num_units,
36+
initializer=initializer,
37+
state_is_tuple=True)
38+
39+
helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
40+
embedding=embedding,
41+
start_tokens=tf.tile([go_token], [batch_size]),
42+
end_token=end_token)
43+
44+
output_layer = tf.layers.Dense(vocab_size, kernel_initializer=initializer)
45+
decoder = tf.contrib.seq2seq.BasicDecoder(
46+
cell=cell,
47+
helper=helper,
48+
initial_state=encoder_state,
49+
output_layer=output_layer)
50+
51+
outputs, state, sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
52+
decoder=decoder,
53+
maximum_iterations=6)
54+
55+
_ = tf.identity(outputs.rnn_output, name="rnn_output")
56+
_ = tf.identity(outputs.sample_id, name="sample_id")
57+
_ = tf.identity(state, name="state")
58+
_ = tf.identity(sequence_lengths, name="sequence_lengths")
59+
60+
output_names_with_port = [
61+
"rnn_output:0",
62+
# "sample_id:0", # incomplete type support for Transpose on onnxruntime 0.2.1
63+
"state:0",
64+
]
65+
66+
self.run_test_case({}, [], output_names_with_port, atol=1e-06, rtol=1e-6)
67+
68+
def test_dynamic_decode_normal_stop(self):
69+
batch_size = 2
70+
num_units = 4
71+
vocab_size = 5
72+
embedding_size = 3
73+
go_token = 0
74+
end_token = 1
75+
76+
embedding = tf.constant(np.ones([vocab_size, embedding_size], dtype=np.float32))
77+
state_val = np.reshape([np.ones([num_units], dtype=np.float32) * i for i in range(batch_size)],
78+
[batch_size, num_units])
79+
encoder_state = tf.nn.rnn_cell.LSTMStateTuple(state_val, state_val)
80+
81+
cell_initializer = init_ops.constant_initializer(
82+
np.array([[-0.9592235, 0.42451382, 0.7437744, -0.54485345, -0.80763197,
83+
0.19663906, -0.22738314, 0.7762785, 0.7464578, 0.27227187,
84+
0.7661047, 0.3596425, -0.8528242, -0.89316916, -0.48946142,
85+
0.87882376],
86+
[0.86586094, -0.75018406, 0.25992537, -0.69368935, 0.2515502,
87+
-0.26379275, 0.8954313, 0.5759742, -0.7753072, -0.4388857,
88+
0.95751476, -0.82085776, -0.9467752, -0.37055635, -0.18570113,
89+
-0.86504984],
90+
[0.02305841, 0.3850248, 0.893692, -0.6866486, -0.83703446,
91+
-0.9828961, 0.3989377, -0.59993076, 0.5330808, 0.6916566,
92+
0.98468065, -0.6047034, 0.10823512, 0.34599304, -0.7834821,
93+
-0.7852347],
94+
[0.81643987, 0.31507468, -0.51369476, -0.12273741, 0.9701307,
95+
-0.79669356, -0.34496522, -0.88750815, -0.17995334, 0.34707904,
96+
-0.09201193, 0.5363934, -0.87229705, -0.5073328, -0.95894027,
97+
0.5481839],
98+
[-0.84093595, -0.2341497, -0.86047816, 0.43370056, -0.39073753,
99+
0.37730122, 0.48026466, 0.3004985, -0.60727096, 0.9043884,
100+
-0.37619448, 0.22490788, -0.03739262, 0.61672115, 0.478899,
101+
-0.40780973],
102+
[0.31202435, -0.22045255, -0.6087918, 0.95115066, 0.00199413,
103+
-0.688287, -0.1103518, 0.4169519, 0.7913246, -0.9844644,
104+
-0.6193857, 0.38659644, -0.4726901, -0.44781208, -0.5174744,
105+
-0.605911],
106+
[0.66771054, 0.34912825, 0.22297978, -0.4990945, 0.24057317,
107+
-0.5540829, 0.92277217, 0.74939895, -0.35278273, -0.21587133,
108+
-0.28613377, -0.8794241, -0.40119147, 0.67175174, -0.22741508,
109+
0.37898326]], dtype=np.float32))
110+
dense_initializer = init_ops.constant_initializer(
111+
np.array([[0.56177187, -0.6233454, 0.73997784, 0.35032558, 0.6479795],
112+
[0.6831174, -0.34233975, 0.39330363, 0.45177555, -0.49649096],
113+
[-0.98890066, 0.6175642, 0.09800482, -0.6721206, 0.48805737],
114+
[0.19671416, 0.2623148, 0.742548, 0.13555217, 0.56009054]], dtype=np.float32))
115+
116+
cell = rnn.LSTMCell(
117+
num_units=num_units,
118+
initializer=cell_initializer,
119+
state_is_tuple=True)
120+
121+
helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
122+
embedding=embedding,
123+
start_tokens=tf.tile([go_token], [batch_size]),
124+
end_token=end_token)
125+
126+
output_layer = tf.layers.Dense(vocab_size, kernel_initializer=dense_initializer)
127+
decoder = tf.contrib.seq2seq.BasicDecoder(
128+
cell=cell,
129+
helper=helper,
130+
initial_state=encoder_state,
131+
output_layer=output_layer)
132+
133+
outputs, state, sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
134+
decoder=decoder,
135+
maximum_iterations=6)
136+
137+
_ = tf.identity(outputs.rnn_output, name="rnn_output")
138+
_ = tf.identity(outputs.sample_id, name="sample_id")
139+
_ = tf.identity(state, name="state")
140+
_ = tf.identity(sequence_lengths, name="sequence_lengths")
141+
142+
output_names_with_port = [
143+
"rnn_output:0",
144+
# "sample_id:0", # incomplete type support for Transpose on onnxruntime 0.2.1
145+
"state:0",
146+
]
147+
148+
self.run_test_case({}, [], output_names_with_port, atol=1e-06, rtol=1e-6)
149+
150+
151+
if __name__ == '__main__':
152+
unittest_main()

tf2onnx/rewriter/loop_rewriter_base.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ def rewrite(self, context):
187187
return REWRITER_RESULT.FAIL
188188

189189
def run_internal(self):
190-
log.debug("enter loop rewriter")
191190
for op in self.g.get_nodes():
192191
if not is_loopcond_op(op):
193192
continue
@@ -265,11 +264,9 @@ def _crop_loop_body_sub_graph(self, context):
265264
output_ids = [out_tensor_value_info.id for out_tensor_value_info in outputs]
266265
ops, enter_nodes, _ = self.find_subgraph(set(input_ids), set(output_ids), self.g, merge_as_end=False)
267266

268-
other_enter_input_ids = []
269267
for enter_node in enter_nodes:
270268
# connect Enter's output to Enter's input
271269
self.g.replace_all_inputs(ops, enter_node.output[0], enter_node.input[0])
272-
other_enter_input_ids.append(enter_node.input[0])
273270

274271
return GraphInfo(ops, inputs, outputs)
275272

@@ -279,11 +276,9 @@ def _crop_loop_condition_sub_graph(self, context):
279276
outputs = [TensorValueInfo(o, self.g) for o in output_ids]
280277
ops, enter_nodes, merge_nodes = self.find_subgraph(set(input_ids), set(output_ids), self.g, merge_as_end=True)
281278

282-
other_enter_input_ids = []
283279
for enter_node in enter_nodes:
284280
# connect Enter's output to Enter's input
285281
self.g.replace_all_inputs(ops, enter_node.output[0], enter_node.input[0])
286-
other_enter_input_ids.append(enter_node.input[0])
287282

288283
dependent_vars = []
289284
for merge_node in merge_nodes:

0 commit comments

Comments
 (0)