Skip to content

Commit 64f4dd5

Browse files
authored
Merge pull request #925 from buddhapuneeth/master
Add stacked LSTM support
2 parents 2c5a914 + 70f40a4 commit 64f4dd5

File tree

4 files changed

+362
-79
lines changed

4 files changed

+362
-79
lines changed

tests/test_stacked_lstm.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""Unit Tests for layered lstm"""
2+
3+
import numpy as np
4+
import tensorflow as tf
5+
6+
from tensorflow.python.ops import init_ops
7+
from backend_test_base import Tf2OnnxBackendTestBase
8+
from common import unittest_main, check_lstm_count, check_opset_after_tf_version, skip_tf2
9+
10+
from tf2onnx.tf_loader import is_tf2
11+
12+
13+
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,cell-var-from-loop
14+
# pylint: disable=invalid-name
15+
16+
if is_tf2():
17+
LSTMCell = tf.compat.v1.nn.rnn_cell.LSTMCell
18+
MultiRNNCell = tf.compat.v1.nn.rnn_cell.MultiRNNCell
19+
dynamic_rnn = tf.compat.v1.nn.dynamic_rnn
20+
else:
21+
LSTMCell = tf.contrib.rnn.LSTMCell
22+
LSTMBlockCell = tf.contrib.rnn.LSTMBlockCell
23+
MultiRNNCell = tf.contrib.rnn.MultiRNNCell
24+
dynamic_rnn = tf.nn.dynamic_rnn
25+
26+
27+
class LSTMLayeredTests(Tf2OnnxBackendTestBase):
28+
@check_opset_after_tf_version("1.15", 8, "might need Scan")
29+
@skip_tf2()
30+
def test_layered_lstm(self):
31+
units = 5
32+
batch_size = 6
33+
x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]], dtype=np.float32)
34+
x_val = np.stack([x_val] * batch_size)
35+
36+
def func(x):
37+
initializer = init_ops.constant_initializer(0.5)
38+
num_layers = 2
39+
40+
# no scope
41+
def lstm_cell():
42+
return LSTMCell(
43+
units,
44+
initializer=initializer,
45+
state_is_tuple=True)
46+
47+
stacked_lstm = MultiRNNCell(
48+
[lstm_cell() for _ in range(num_layers)])
49+
outputs, cell_state = dynamic_rnn(
50+
stacked_lstm,
51+
x,
52+
dtype=tf.float32)
53+
return tf.identity(outputs, name="output"), tf.identity(cell_state, name="cell_state")
54+
55+
input_names_with_port = ["input_1:0"]
56+
feed_dict = {"input_1:0": x_val}
57+
58+
output_names_with_port = ["output:0", "cell_state:0"]
59+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
60+
graph_validator=lambda g: check_lstm_count(g, 2))
61+
62+
63+
if __name__ == '__main__':
64+
unittest_main()

0 commit comments

Comments
 (0)