|
| 1 | +"""Unit Tests for layered lstm""" |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import tensorflow as tf |
| 5 | + |
| 6 | +from tensorflow.python.ops import init_ops |
| 7 | +from backend_test_base import Tf2OnnxBackendTestBase |
| 8 | +from common import unittest_main, check_lstm_count, check_opset_after_tf_version, skip_tf2 |
| 9 | + |
| 10 | +from tf2onnx.tf_loader import is_tf2 |
| 11 | + |
| 12 | + |
| 13 | +# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,cell-var-from-loop |
| 14 | +# pylint: disable=invalid-name |
| 15 | + |
| 16 | +if is_tf2(): |
| 17 | + LSTMCell = tf.compat.v1.nn.rnn_cell.LSTMCell |
| 18 | + MultiRNNCell = tf.compat.v1.nn.rnn_cell.MultiRNNCell |
| 19 | + dynamic_rnn = tf.compat.v1.nn.dynamic_rnn |
| 20 | +else: |
| 21 | + LSTMCell = tf.contrib.rnn.LSTMCell |
| 22 | + LSTMBlockCell = tf.contrib.rnn.LSTMBlockCell |
| 23 | + MultiRNNCell = tf.contrib.rnn.MultiRNNCell |
| 24 | + dynamic_rnn = tf.nn.dynamic_rnn |
| 25 | + |
| 26 | + |
| 27 | +class LSTMLayeredTests(Tf2OnnxBackendTestBase): |
| 28 | + @check_opset_after_tf_version("1.15", 8, "might need Scan") |
| 29 | + @skip_tf2() |
| 30 | + def test_layered_lstm(self): |
| 31 | + units = 5 |
| 32 | + batch_size = 6 |
| 33 | + x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]], dtype=np.float32) |
| 34 | + x_val = np.stack([x_val] * batch_size) |
| 35 | + |
| 36 | + def func(x): |
| 37 | + initializer = init_ops.constant_initializer(0.5) |
| 38 | + num_layers = 2 |
| 39 | + |
| 40 | + # no scope |
| 41 | + def lstm_cell(): |
| 42 | + return LSTMCell( |
| 43 | + units, |
| 44 | + initializer=initializer, |
| 45 | + state_is_tuple=True) |
| 46 | + |
| 47 | + stacked_lstm = MultiRNNCell( |
| 48 | + [lstm_cell() for _ in range(num_layers)]) |
| 49 | + outputs, cell_state = dynamic_rnn( |
| 50 | + stacked_lstm, |
| 51 | + x, |
| 52 | + dtype=tf.float32) |
| 53 | + return tf.identity(outputs, name="output"), tf.identity(cell_state, name="cell_state") |
| 54 | + |
| 55 | + input_names_with_port = ["input_1:0"] |
| 56 | + feed_dict = {"input_1:0": x_val} |
| 57 | + |
| 58 | + output_names_with_port = ["output:0", "cell_state:0"] |
| 59 | + self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06, |
| 60 | + graph_validator=lambda g: check_lstm_count(g, 2)) |
| 61 | + |
| 62 | + |
| 63 | +if __name__ == '__main__': |
| 64 | + unittest_main() |
0 commit comments