Skip to content

Commit f19a036

Browse files
q-ycong-p and congyc authored
Add Keras LSTM support (#1752)
Signed-off-by: congyc <[email protected]> Co-authored-by: congyc <[email protected]>
1 parent 42e800d commit f19a036

File tree

6 files changed

+394
-164
lines changed

6 files changed

+394
-164
lines changed

tests/keras2onnx_unit_tests/test_layers.py

Lines changed: 120 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -45,13 +45,15 @@
4545
Layer = keras.layers.Layer
4646
LeakyReLU = keras.layers.LeakyReLU
4747
LSTM = keras.layers.LSTM
48+
LSTMCell = keras.layers.LSTMCell
4849
Maximum = keras.layers.Maximum
4950
MaxPool1D = keras.layers.MaxPool1D
5051
MaxPool3D = keras.layers.MaxPool3D
5152
MaxPooling2D = keras.layers.MaxPooling2D
5253
Model = keras.models.Model
5354
Multiply = keras.layers.Multiply
5455
Reshape = keras.layers.Reshape
56+
RNN = keras.layers.RNN
5557
SeparableConv1D = keras.layers.SeparableConv1D
5658
SeparableConv2D = keras.layers.SeparableConv2D
5759
Sequential = keras.models.Sequential
@@ -66,11 +68,15 @@
6668
if not is_keras_older_than("2.2.4"):
6769
ReLU = keras.layers.ReLU
6870

71+
GRU_CLASSES = [(GRU, "v1")]
72+
LSTM_CLASSES = [(LSTM, LSTMCell, "v1")]
6973
RNN_CLASSES = [SimpleRNN, GRU, LSTM]
7074

7175
if is_tf_keras and is_tensorflow_later_than("1.14.0"):
7276
# Add the TF v2 compatibility layers (available after TF 1.14)
7377
from tensorflow.python.keras.layers import recurrent_v2
78+
GRU_CLASSES.append((recurrent_v2.GRU, "v2"))
79+
LSTM_CLASSES.append((recurrent_v2.LSTM, recurrent_v2.LSTMCell, "v2"))
7480
RNN_CLASSES.extend([recurrent_v2.GRU, recurrent_v2.LSTM])
7581

7682

@@ -1829,134 +1835,175 @@ def test_simpleRNN(runner):
18291835
assert runner(onnx_model.graph.name, onnx_model, [x, s], expected)
18301836

18311837

1832-
def test_GRU(runner):
1838+
@pytest.mark.parametrize("gru_class, rnn_version", GRU_CLASSES)
1839+
@pytest.mark.parametrize("return_sequences", [True, False])
1840+
def test_GRU(runner, gru_class, rnn_version, return_sequences):
18331841
inputs1 = keras.Input(shape=(3, 1))
18341842

1835-
cls = GRU(2, return_state=False, return_sequences=False)
1843+
# GRU with no initial state
1844+
cls = gru_class(2, return_state=False, return_sequences=False)
18361845
oname = cls(inputs1)
18371846
model = keras.Model(inputs=inputs1, outputs=[oname])
18381847
onnx_model = convert_keras(model, model.name)
1839-
1848+
if rnn_version == "v2":
1849+
assert no_loops_in_tf2(onnx_model)
18401850
data = np.array([0.1, 0.2, 0.3]).astype(np.float32).reshape((1, 3, 1))
18411851
expected = model.predict(data)
18421852
assert runner(onnx_model.graph.name, onnx_model, data, expected)
18431853

18441854
# GRU with initial state
1845-
for return_sequences in [True, False]:
1846-
cls = GRU(2, return_state=False, return_sequences=return_sequences)
1847-
initial_state_input = keras.Input(shape=(2,))
1848-
oname = cls(inputs1, initial_state=initial_state_input)
1849-
model = keras.Model(inputs=[inputs1, initial_state_input], outputs=[oname])
1850-
onnx_model = convert_keras(model, model.name)
1851-
1852-
data = np.array([0.1, 0.2, 0.3]).astype(np.float32).reshape((1, 3, 1))
1853-
init_state = np.array([0.4, 0.5]).astype(np.float32).reshape((1, 2))
1854-
init_state_onnx = np.array([0.4, 0.5]).astype(np.float32).reshape((1, 2))
1855-
expected = model.predict([data, init_state])
1856-
assert runner(onnx_model.graph.name, onnx_model, [data, init_state_onnx], expected)
1855+
cls = gru_class(2, return_state=False, return_sequences=return_sequences)
1856+
initial_state_input = keras.Input(shape=(2,))
1857+
oname = cls(inputs1, initial_state=initial_state_input)
1858+
model = keras.Model(inputs=[inputs1, initial_state_input], outputs=[oname])
1859+
onnx_model = convert_keras(model, model.name)
1860+
if rnn_version == "v2":
1861+
assert no_loops_in_tf2(onnx_model)
1862+
data = np.array([0.1, 0.2, 0.3]).astype(np.float32).reshape((1, 3, 1))
1863+
init_state = np.array([0.4, 0.5]).astype(np.float32).reshape((1, 2))
1864+
init_state_onnx = np.array([0.4, 0.5]).astype(np.float32).reshape((1, 2))
1865+
expected = model.predict([data, init_state])
1866+
assert runner(onnx_model.graph.name, onnx_model, [data, init_state_onnx], expected)
18571867

18581868

18591869
@pytest.mark.skipif(not is_tf_keras and is_tf2 and is_tensorflow_older_than('2.2'),
18601870
reason="Fails due to some reason involving bad graph captures. Works in new versions and tf_keras")
1861-
def test_GRU_2(runner):
1871+
@pytest.mark.parametrize("gru_class, rnn_version", GRU_CLASSES)
1872+
def test_GRU_2(runner, gru_class, rnn_version):
18621873
model = keras.Sequential(name='TestGRU')
1863-
model.add(keras.layers.GRU(400, reset_after=True, input_shape=(1, 257)))
1864-
model.add(keras.layers.Dense(257, activation='sigmoid'))
1874+
model.add(gru_class(400, reset_after=True, input_shape=(1, 257)))
1875+
model.add(Dense(257, activation='sigmoid'))
18651876
onnx_model = convert_keras(model, name=model.name)
1877+
if rnn_version == "v2":
1878+
assert no_loops_in_tf2(onnx_model)
18661879
data = np.random.rand(3, 257).astype(np.float32).reshape((3, 1, 257))
18671880
expected = model.predict(data)
18681881
assert runner(onnx_model.graph.name, onnx_model, data, expected)
18691882

18701883

1884+
@pytest.mark.parametrize("lstm_class, lstmcell_class, rnn_version", LSTM_CLASSES)
18711885
@pytest.mark.parametrize('return_sequences', [False, True])
1872-
def test_LSTM(runner, return_sequences):
1886+
@pytest.mark.parametrize('use_bias', [False, True])
1887+
def test_LSTM(runner, lstm_class, lstmcell_class, rnn_version, return_sequences, use_bias):
18731888
inputs1 = keras.Input(shape=(3, 5))
18741889
data = np.random.rand(3, 5).astype(np.float32).reshape((1, 3, 5))
1875-
for use_bias in [True, False]:
1876-
for return_sequences in [True, False]:
1877-
cls = LSTM(units=2, return_state=True, return_sequences=return_sequences, use_bias=use_bias)
1878-
lstm1, state_h, state_c = cls(inputs1)
1879-
model = keras.Model(inputs=inputs1, outputs=[lstm1, state_h, state_c])
1880-
onnx_model = convert_keras(model, model.name)
1881-
expected = model.predict(data)
1882-
assert runner(onnx_model.graph.name, onnx_model, data, expected)
1890+
cls1 = lstm_class(units=2, return_state=True, return_sequences=return_sequences, use_bias=use_bias)
1891+
cls2 = RNN(lstmcell_class(units=2, use_bias=use_bias), return_state=True, return_sequences=return_sequences)
1892+
lstm1, state_h, state_c = cls1(inputs1)
1893+
lstm2, state_h_2, state_c_2 = cls2(inputs1)
1894+
model = keras.Model(inputs=inputs1, outputs=[lstm1, state_h, state_c, lstm2, state_h_2, state_c_2])
1895+
onnx_model = convert_keras(model, model.name)
1896+
if rnn_version == "v2":
1897+
assert no_loops_in_tf2(onnx_model)
1898+
expected = model.predict(data)
1899+
assert runner(onnx_model.graph.name, onnx_model, data, expected)
1900+
18831901

18841902
@pytest.mark.skipif((is_tensorflow_older_than('1.14.0') or (not is_tf_keras)), reason='old tf version')
1885-
def test_LSTM_rev(runner):
1903+
@pytest.mark.parametrize("lstm_class, lstmcell_class, rnn_version", LSTM_CLASSES)
1904+
@pytest.mark.parametrize('return_sequences', [False, True])
1905+
@pytest.mark.parametrize('use_bias', [False, True])
1906+
def test_LSTM_rev(runner, lstm_class, lstmcell_class, rnn_version, return_sequences, use_bias):
18861907
inputs1 = keras.Input(shape=(3, 5))
18871908
data = np.random.rand(3, 5).astype(np.float32).reshape((1, 3, 5))
1888-
for use_bias in [True, False]:
1889-
for return_sequences in [True, False]:
1890-
cls = LSTM(units=2, return_state=True, go_backwards=True, return_sequences=return_sequences, use_bias=use_bias)
1891-
lstm1, state_h, state_c = cls(inputs1)
1892-
model = keras.Model(inputs=inputs1, outputs=[lstm1, state_h, state_c])
1893-
onnx_model = convert_keras(model, model.name)
1894-
expected = model.predict(data)
1895-
assert runner(onnx_model.graph.name, onnx_model, data, expected)
1909+
cls = lstm_class(units=2, return_state=True, go_backwards=True, return_sequences=return_sequences, use_bias=use_bias)
1910+
lstm1, state_h, state_c = cls(inputs1)
1911+
model = keras.Model(inputs=inputs1, outputs=[lstm1, state_h, state_c])
1912+
onnx_model = convert_keras(model, model.name)
1913+
if rnn_version == "v2":
1914+
assert no_loops_in_tf2(onnx_model)
1915+
expected = model.predict(data)
1916+
assert runner(onnx_model.graph.name, onnx_model, data, expected)
18961917

18971918

18981919
@pytest.mark.skipif((is_tensorflow_older_than('1.14.0') or (not is_tf_keras)),
18991920
reason="keras LSTM does not have time_major attribute")
1900-
def test_LSTM_time_major_return_seq_true(runner):
1921+
@pytest.mark.parametrize("lstm_class, lstmcell_class, rnn_version", LSTM_CLASSES)
1922+
def test_LSTM_time_major_return_seq_true(runner, lstm_class, lstmcell_class, rnn_version):
19011923
inputs1 = keras.Input(shape=(3, 5))
19021924
data = np.random.rand(1, 3, 5).astype(np.float32)
19031925
# Transpose input to be time major
19041926
input_transposed = tf.transpose(inputs1, perm=[1, 0, 2])
1905-
lstm1, state_h, state_c = LSTM(units=2, time_major=True, return_state=True,
1927+
lstm1, state_h, state_c = lstm_class(units=2, time_major=True, return_state=True,
19061928
return_sequences=True)(input_transposed)
1929+
lstm2, state_h_2, state_c_2 = RNN(lstmcell_class(units=2), time_major=True, return_state=True,
1930+
return_sequences=True)(input_transposed)
19071931
lstm1_trans = tf.transpose(lstm1, perm=[1, 0, 2])
1908-
model = keras.Model(inputs=inputs1, outputs=[lstm1_trans, state_h, state_c])
1932+
lstm2_trans = tf.transpose(lstm2, perm=[1,0,2])
1933+
model = keras.Model(inputs=inputs1, outputs=[lstm1_trans, state_h, state_c,
1934+
lstm2_trans, state_h_2, state_c_2])
19091935
onnx_model = convert_keras(model, model.name)
1936+
if rnn_version == "v2":
1937+
assert no_loops_in_tf2(onnx_model)
19101938
expected = model.predict(data)
19111939
assert runner(onnx_model.graph.name, onnx_model, data, expected)
19121940

19131941

19141942
@pytest.mark.skipif((is_tensorflow_older_than('1.14.0') or (not is_tf_keras)),
19151943
reason="keras LSTM does not have time_major attribute")
1916-
def test_LSTM_time_major_return_seq_false(runner):
1944+
@pytest.mark.parametrize("lstm_class, lstmcell_class, rnn_version", LSTM_CLASSES)
1945+
def test_LSTM_time_major_return_seq_false(runner, lstm_class, lstmcell_class, rnn_version):
19171946
inputs1 = keras.Input(shape=(3, 5))
19181947
data = np.random.rand(1, 3, 5).astype(np.float32)
19191948
# Transpose input to be time major
19201949
input_transposed = tf.transpose(inputs1, perm=[1, 0, 2])
1921-
lstm1, state_h, state_c = LSTM(units=2, time_major=True, return_state=True,
1950+
lstm1, state_h, state_c = lstm_class(units=2, time_major=True, return_state=True,
19221951
return_sequences=False)(input_transposed)
1923-
model = keras.Model(inputs=inputs1, outputs=[lstm1, state_h, state_c])
1952+
lstm2, state_h_2, state_c_2 = RNN(lstmcell_class(units=2), time_major=True, return_state=True,
1953+
return_sequences=False)(input_transposed)
1954+
model = keras.Model(inputs=inputs1, outputs=[lstm1, state_h, state_c,
1955+
lstm2, state_h_2, state_c_2])
19241956
onnx_model = convert_keras(model, model.name)
1957+
if rnn_version == "v2":
1958+
assert no_loops_in_tf2(onnx_model)
19251959
expected = model.predict(data)
19261960
assert runner(onnx_model.graph.name, onnx_model, data, expected)
19271961

19281962

1929-
def test_LSTM_with_bias(runner):
1963+
@pytest.mark.parametrize("lstm_class, lstmcell_class, rnn_version", LSTM_CLASSES)
1964+
def test_LSTM_with_bias(runner, lstm_class, lstmcell_class, rnn_version):
19301965
inputs1 = keras.Input(shape=(1, 1))
1931-
cls = LSTM(units=1, return_state=True, return_sequences=True)
1966+
cls = lstm_class(units=1, return_state=True, return_sequences=True)
19321967
lstm1, state_h, state_c = cls(inputs1)
1933-
model = keras.Model(inputs=inputs1, outputs=[lstm1, state_h, state_c])
1968+
lstm2, state_h_2, state_c_2 = RNN(lstmcell_class(units=1), return_state=True,
1969+
return_sequences=True)(inputs1)
1970+
model = keras.Model(inputs=inputs1, outputs=[lstm1, state_h, state_c,
1971+
lstm2, state_h_2, state_c_2])
19341972
# Set weights: kernel, recurrent_kernel and bias
1935-
model.set_weights((np.array([[1, 2, 3, 4]]), np.array([[5, 6, 7, 8]]), np.array([1, 2, 3, 4])))
1973+
model.set_weights((np.array([[1, 2, 3, 4]]), np.array([[5, 6, 7, 8]]), np.array([1, 2, 3, 4]),
1974+
np.array([[1, 2, 3, 4]]), np.array([[5, 6, 7, 8]]), np.array([1, 2, 3, 4])))
19361975
data = np.random.rand(1, 1).astype(np.float32).reshape((1, 1, 1))
19371976
onnx_model = convert_keras(model, model.name)
1938-
1977+
if rnn_version == "v2":
1978+
assert no_loops_in_tf2(onnx_model)
19391979
expected = model.predict(data)
19401980
assert runner(onnx_model.graph.name, onnx_model, data, expected)
19411981

19421982

1943-
def test_LSTM_reshape(runner):
1983+
@pytest.mark.parametrize("lstm_class, lstmcell_class, rnn_version", LSTM_CLASSES)
1984+
def test_LSTM_reshape(runner, lstm_class, lstmcell_class, rnn_version):
19441985
input_dim = 7
19451986
sequence_len = 3
19461987
inputs1 = keras.Input(shape=(sequence_len, input_dim))
1947-
cls = LSTM(units=5, return_state=False, return_sequences=True)
1988+
cls = lstm_class(units=5, return_state=False, return_sequences=True)
19481989
lstm1 = cls(inputs1)
1990+
lstm2 = RNN(lstmcell_class(units=5), return_state=False, return_sequences=True)(inputs1)
1991+
19491992
output = Reshape((sequence_len, 5))(lstm1)
1950-
model = keras.Model(inputs=inputs1, outputs=output)
1993+
output_2 = Reshape((sequence_len, 5))(lstm2)
1994+
model = keras.Model(inputs=inputs1, outputs=[output, output_2])
19511995
model.compile(optimizer='sgd', loss='mse')
19521996

19531997
onnx_model = convert_keras(model, 'test')
1998+
if rnn_version == "v2":
1999+
assert no_loops_in_tf2(onnx_model)
19542000
data = np.random.rand(input_dim, sequence_len).astype(np.float32).reshape((1, sequence_len, input_dim))
19552001
expected = model.predict(data)
19562002
assert runner('tf_lstm', onnx_model, data, expected)
19572003

19582004

1959-
def test_LSTM_with_initializer(runner):
2005+
@pytest.mark.parametrize("lstm_class, lstmcell_class, rnn_version", LSTM_CLASSES)
2006+
def test_LSTM_with_initializer(runner, lstm_class, lstmcell_class, rnn_version):
19602007
# batch_size = N
19612008
# seq_length = H
19622009
# input_size = W
@@ -1971,34 +2018,44 @@ def test_LSTM_with_initializer(runner):
19712018
state_c = keras.Input(shape=(C,), name='state_c')
19722019

19732020
# create keras model
1974-
lstm_layer = LSTM(units=C, activation='relu', return_sequences=True)(inputs,
2021+
lstm_layer = lstm_class(units=C, activation='relu', return_sequences=True)(inputs,
19752022
initial_state=[state_h,
19762023
state_c])
2024+
lstm_layer_2 = RNN(lstmcell_class(units=C, activation='relu'),
2025+
return_sequences=True)(inputs, initial_state=[state_h, state_c])
19772026
outputs = Dense(W, activation='sigmoid')(lstm_layer)
1978-
keras_model = keras.Model(inputs=[inputs, state_h, state_c], outputs=outputs)
2027+
outputs_2 = Dense(W, activation='sigmoid')(lstm_layer_2)
2028+
keras_model = keras.Model(inputs=[inputs, state_h, state_c],
2029+
outputs=[outputs, outputs_2])
19792030

19802031
x = np.random.rand(1, H, W).astype(np.float32)
19812032
sh = np.random.rand(1, C).astype(np.float32)
19822033
sc = np.random.rand(1, C).astype(np.float32)
1983-
expected = keras_model.predict([x, sh, sc])
19842034
onnx_model = convert_keras(keras_model, keras_model.name)
2035+
if rnn_version == "v2":
2036+
assert no_loops_in_tf2(onnx_model)
2037+
expected = keras_model.predict([x, sh, sc])
19852038
assert runner(onnx_model.graph.name, onnx_model, {"inputs": x, 'state_h': sh, 'state_c': sc}, expected)
19862039

19872040

19882041
@pytest.mark.skipif(get_maximum_opset_supported() < 5,
19892042
reason="None seq_length LSTM is not supported before opset 5.")
19902043
@pytest.mark.skipif(is_tensorflow_older_than('2.2'), reason='require 2.2 to fix freezing')
1991-
def test_LSTM_seqlen_none(runner):
2044+
@pytest.mark.parametrize("lstm_class, lstmcell_class, rnn_version", LSTM_CLASSES)
2045+
@pytest.mark.parametrize('return_sequences', [False, True])
2046+
def test_LSTM_seqlen_none(runner, lstm_class, lstmcell_class, rnn_version, return_sequences):
19922047
lstm_dim = 2
19932048
data = np.random.rand(1, 5, 1).astype(np.float32)
1994-
for return_sequences in [True, False]:
1995-
inp = Input(batch_shape=(1, None, 1))
1996-
out = LSTM(lstm_dim, return_sequences=return_sequences, stateful=True)(inp)
1997-
keras_model = keras.Model(inputs=inp, outputs=out)
1998-
1999-
onnx_model = convert_keras(keras_model)
2000-
expected = keras_model.predict(data)
2001-
assert runner(onnx_model.graph.name, onnx_model, data, expected)
2049+
inp = Input(batch_shape=(1, None, 1))
2050+
out = lstm_class(lstm_dim, return_sequences=return_sequences, stateful=True)(inp)
2051+
out_2 = RNN(lstmcell_class(lstm_dim), return_sequences=return_sequences, stateful=True)(inp)
2052+
keras_model = keras.Model(inputs=inp, outputs=[out, out_2])
2053+
2054+
onnx_model = convert_keras(keras_model)
2055+
if rnn_version == "v2":
2056+
assert no_loops_in_tf2(onnx_model)
2057+
expected = keras_model.predict(data)
2058+
assert runner(onnx_model.graph.name, onnx_model, data, expected)
20022059

20032060

20042061
@pytest.mark.parametrize("return_sequences", [True, False])

tests/test_backend.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5464,6 +5464,7 @@ def func(input_val):
54645464
self.config.opset = current_opset
54655465

54665466
@check_tf_min_version("1.14")
5467+
@skip_tfjs("Fails to run tfjs model")
54675468
def test_rfft_ops(self):
54685469

54695470
def dft_slow(x, M, fft_length):

tests/test_lstm.py

Lines changed: 25 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,14 +9,16 @@
99
from tensorflow.python.ops import init_ops
1010
from tensorflow.python.ops import variable_scope
1111
from backend_test_base import Tf2OnnxBackendTestBase
12-
from common import unittest_main, check_opset_after_tf_version, skip_tf2, skip_tf_versions, check_op_count
12+
from common import check_tf_min_version, unittest_main, check_opset_after_tf_version, \
13+
skip_tf2, skip_tf_versions, check_op_count
1314

1415
from tf2onnx.tf_loader import is_tf2
1516

1617

1718
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,cell-var-from-loop
1819
# pylint: disable=invalid-name
1920

21+
2022
if is_tf2():
2123
# There is no LSTMBlockCell in tf-2.x
2224
BasicLSTMCell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell
@@ -726,6 +728,28 @@ def func(x, y1, y2):
726728
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06,
727729
require_lstm_count=2)
728730

731+
@check_tf_min_version("2.0")
732+
def test_keras_lstm(self):
733+
in_shape = [10, 3]
734+
x_val = np.random.uniform(size=[2, 10, 3]).astype(np.float32)
735+
736+
model_in = tf.keras.layers.Input(tuple(in_shape), batch_size=2)
737+
x = tf.keras.layers.LSTM(
738+
units=5,
739+
return_sequences=True,
740+
return_state=True,
741+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
742+
recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
743+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43)
744+
)(model_in)
745+
model = tf.keras.models.Model(inputs=model_in, outputs=x)
746+
747+
def func(x):
748+
y = model(x)
749+
# names for input and outputs for tests
750+
return tf.identity(y[0], name="output"), tf.identity(y[1], name="output1")
751+
self.run_test_case(func, {"input:0": x_val}, [], ["output:0", "output1:0"], rtol=1e-05, atol=1e-06)
752+
729753

730754
if __name__ == '__main__':
731755
unittest_main()

0 commit comments

Comments
 (0)