Skip to content

Commit 723e3a4

Browse files
committed
Add stacked LSTM support
- Added a new base class, exclusive to LSTM, that supports stacked LSTMs. - This base class can be combined with unit_rnn_rewriter_base once support for stacked layers is added for other RNN types such as GRU. - Modified the LSTM rewriter so that it allows multiple cell matches; the LSTM index is now passed to all of the variable-finder and weight/bias methods. - Added a new test for stacked LSTMs.
1 parent 68d7b88 commit 723e3a4

File tree

4 files changed

+355
-80
lines changed

4 files changed

+355
-80
lines changed

tests/test_stacked_lstm.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""Unit Tests for layered lstm"""
2+
3+
import numpy as np
4+
import tensorflow as tf
5+
6+
from tensorflow.python.ops import init_ops
7+
from backend_test_base import Tf2OnnxBackendTestBase
8+
from common import unittest_main, check_lstm_count, skip_tf2
9+
10+
from tf2onnx.tf_loader import is_tf2
11+
12+
# Bind the RNN helpers used below to whichever TensorFlow namespace provides
# them: tf.contrib / tf.nn when is_tf2() is false, tf.compat.v1 otherwise.
if not is_tf2():
    dynamic_rnn = tf.nn.dynamic_rnn
    MultiRNNCell = tf.contrib.rnn.MultiRNNCell
    # Only bound on this branch.
    LSTMBlockCell = tf.contrib.rnn.LSTMBlockCell
    LSTMCell = tf.contrib.rnn.LSTMCell
else:
    dynamic_rnn = tf.compat.v1.nn.dynamic_rnn
    MultiRNNCell = tf.compat.v1.nn.rnn_cell.MultiRNNCell
    LSTMCell = tf.compat.v1.nn.rnn_cell.LSTMCell
21+
22+
class LSTMLayeredTests(Tf2OnnxBackendTestBase):
    """Backend tests for LSTM cells stacked through MultiRNNCell."""

    def test_layered_lstm(self):
        """Run a two-layer stacked LSTM through run_test_case.

        The converted graph is validated with ``check_lstm_count(g, 2)``.
        """
        hidden_units = 5
        num_copies = 6
        # A 4x2 float32 matrix, repeated along a new leading axis -> shape (6, 4, 2).
        sequence = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]], dtype=np.float32)
        x_val = np.stack([sequence for _ in range(num_copies)])

        def func(x):
            weight_init = init_ops.constant_initializer(0.5)

            # no scope
            layers = []
            for _ in range(2):
                layers.append(
                    LSTMCell(hidden_units, initializer=weight_init, state_is_tuple=True))

            rnn_outputs, final_state = dynamic_rnn(MultiRNNCell(layers), x, dtype=tf.float32)
            # Name both results so they can be fetched as "output:0" / "cell_state:0".
            named_output = tf.identity(rnn_outputs, name="output")
            named_state = tf.identity(final_state, name="cell_state")
            return named_output, named_state

        input_name = "input_1:0"
        self.run_test_case(
            func,
            {input_name: x_val},
            [input_name],
            ["output:0", "cell_state:0"],
            rtol=1e-06,
            graph_validator=lambda g: check_lstm_count(g, 2))
54+
55+
# Run the tests through unittest_main() when this module is executed as a script.
if "__main__" == __name__:
    unittest_main()

0 commit comments

Comments
 (0)