# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

""" Unit Tests for tf.contrib.seq2seq """

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.python.ops import init_ops
from backend_test_base import Tf2OnnxBackendTestBase
from common import unittest_main


# pylint: disable=missing-docstring

class Seq2SeqTests(Tf2OnnxBackendTestBase):
    def test_dynamic_decode_maximum_iterations(self):
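        # All weights are initialized to the same constant and the embedding is
        # all ones, so every decode step produces identical logits; greedy
        # sampling never picks the end token and decoding only stops at the
        # maximum_iterations bound set below.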
        batch_size = 2
        num_units = 4
        vocab_size = 5
        embedding_size = 3
        go_token = 0
        end_token = 1

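        # Constant embedding table and a fixed initial LSTM state: the graph has
        # no placeholders, so the test is driven with an empty feed dict.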
        embedding = tf.constant(np.ones([vocab_size, embedding_size], dtype=np.float32))
        state_val = np.reshape([np.ones([num_units], dtype=np.float32) * i for i in range(batch_size)],
                               [batch_size, num_units])
        encoder_state = tf.nn.rnn_cell.LSTMStateTuple(state_val, state_val)
        initializer = init_ops.constant_initializer(0.5)
        cell = rnn.LSTMCell(
            num_units=num_units,
            initializer=initializer,
            state_is_tuple=True)

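        # GreedyEmbeddingHelper feeds the argmax token of each step back through
        # the embedding as the next decoder input.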
        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            embedding=embedding,
            start_tokens=tf.tile([go_token], [batch_size]),
            end_token=end_token)

        output_layer = tf.layers.Dense(vocab_size, kernel_initializer=initializer)
        decoder = tf.contrib.seq2seq.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=encoder_state,
            output_layer=output_layer)

        outputs, state, sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
            decoder=decoder,
            maximum_iterations=6)

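        # Give the decoder outputs stable tensor names so they can be referenced
        # as graph outputs below.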
        _ = tf.identity(outputs.rnn_output, name="rnn_output")
        _ = tf.identity(outputs.sample_id, name="sample_id")
        _ = tf.identity(state, name="state")
        _ = tf.identity(sequence_lengths, name="sequence_lengths")

        output_names_with_port = [
            "rnn_output:0",
            # "sample_id:0", # incomplete type support for Transpose on onnxruntime 0.2.1
            "state:0",
        ]

        self.run_test_case({}, [], output_names_with_port, atol=1e-06, rtol=1e-06)

    def test_dynamic_decode_normal_stop(self):
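        # Same setup as above, but with hand-picked weights intended to make the
        # greedy decoder emit the end token before the maximum_iterations cap,
        # exercising the normal stop condition.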
        batch_size = 2
        num_units = 4
        vocab_size = 5
        embedding_size = 3
        go_token = 0
        end_token = 1

        embedding = tf.constant(np.ones([vocab_size, embedding_size], dtype=np.float32))
        state_val = np.reshape([np.ones([num_units], dtype=np.float32) * i for i in range(batch_size)],
                               [batch_size, num_units])
        encoder_state = tf.nn.rnn_cell.LSTMStateTuple(state_val, state_val)

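        # Hard-coded LSTM kernel ([embedding_size + num_units, 4 * num_units]) and
        # output-projection kernel ([num_units, vocab_size]) so the decoded tokens,
        # and hence the stopping step, are deterministic.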
        cell_initializer = init_ops.constant_initializer(
            np.array([[-0.9592235, 0.42451382, 0.7437744, -0.54485345, -0.80763197,
                       0.19663906, -0.22738314, 0.7762785, 0.7464578, 0.27227187,
                       0.7661047, 0.3596425, -0.8528242, -0.89316916, -0.48946142,
                       0.87882376],
                      [0.86586094, -0.75018406, 0.25992537, -0.69368935, 0.2515502,
                       -0.26379275, 0.8954313, 0.5759742, -0.7753072, -0.4388857,
                       0.95751476, -0.82085776, -0.9467752, -0.37055635, -0.18570113,
                       -0.86504984],
                      [0.02305841, 0.3850248, 0.893692, -0.6866486, -0.83703446,
                       -0.9828961, 0.3989377, -0.59993076, 0.5330808, 0.6916566,
                       0.98468065, -0.6047034, 0.10823512, 0.34599304, -0.7834821,
                       -0.7852347],
                      [0.81643987, 0.31507468, -0.51369476, -0.12273741, 0.9701307,
                       -0.79669356, -0.34496522, -0.88750815, -0.17995334, 0.34707904,
                       -0.09201193, 0.5363934, -0.87229705, -0.5073328, -0.95894027,
                       0.5481839],
                      [-0.84093595, -0.2341497, -0.86047816, 0.43370056, -0.39073753,
                       0.37730122, 0.48026466, 0.3004985, -0.60727096, 0.9043884,
                       -0.37619448, 0.22490788, -0.03739262, 0.61672115, 0.478899,
                       -0.40780973],
                      [0.31202435, -0.22045255, -0.6087918, 0.95115066, 0.00199413,
                       -0.688287, -0.1103518, 0.4169519, 0.7913246, -0.9844644,
                       -0.6193857, 0.38659644, -0.4726901, -0.44781208, -0.5174744,
                       -0.605911],
                      [0.66771054, 0.34912825, 0.22297978, -0.4990945, 0.24057317,
                       -0.5540829, 0.92277217, 0.74939895, -0.35278273, -0.21587133,
                       -0.28613377, -0.8794241, -0.40119147, 0.67175174, -0.22741508,
                       0.37898326]], dtype=np.float32))
        dense_initializer = init_ops.constant_initializer(
            np.array([[0.56177187, -0.6233454, 0.73997784, 0.35032558, 0.6479795],
                      [0.6831174, -0.34233975, 0.39330363, 0.45177555, -0.49649096],
                      [-0.98890066, 0.6175642, 0.09800482, -0.6721206, 0.48805737],
                      [0.19671416, 0.2623148, 0.742548, 0.13555217, 0.56009054]], dtype=np.float32))

        cell = rnn.LSTMCell(
            num_units=num_units,
            initializer=cell_initializer,
            state_is_tuple=True)

        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            embedding=embedding,
            start_tokens=tf.tile([go_token], [batch_size]),
            end_token=end_token)

        output_layer = tf.layers.Dense(vocab_size, kernel_initializer=dense_initializer)
        decoder = tf.contrib.seq2seq.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=encoder_state,
            output_layer=output_layer)

        outputs, state, sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
            decoder=decoder,
            maximum_iterations=6)

        _ = tf.identity(outputs.rnn_output, name="rnn_output")
        _ = tf.identity(outputs.sample_id, name="sample_id")
        _ = tf.identity(state, name="state")
        _ = tf.identity(sequence_lengths, name="sequence_lengths")

        output_names_with_port = [
            "rnn_output:0",
            # "sample_id:0", # incomplete type support for Transpose on onnxruntime 0.2.1
            "state:0",
        ]

        self.run_test_case({}, [], output_names_with_port, atol=1e-06, rtol=1e-06)


if __name__ == '__main__':
    unittest_main()