Skip to content

Commit 85bca92

Browse files
Merge pull request #1012 from onnx/tom/FixLSTM
Fixed LSTM conversion for TF2
2 parents 38b19b2 + 29522aa commit 85bca92

File tree

2 files changed

+117
-101
lines changed

2 files changed

+117
-101
lines changed

tests/test_lstm.py

Lines changed: 44 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from tensorflow.python.ops import init_ops
1515
from tensorflow.python.ops import variable_scope
1616
from backend_test_base import Tf2OnnxBackendTestBase
17-
from common import unittest_main, check_lstm_count, check_opset_after_tf_version, skip_tf2
17+
from common import unittest_main, check_opset_after_tf_version, skip_tf2, skip_tf_versions
1818

1919
from tf2onnx.tf_loader import is_tf2
2020

@@ -41,12 +41,10 @@
4141

4242
class LSTMTests(Tf2OnnxBackendTestBase):
4343
@check_opset_after_tf_version("1.15", 8, "might need Scan")
44-
@skip_tf2()
4544
def test_test_single_dynamic_lstm_state_is_tuple(self):
4645
self.internal_test_single_dynamic_lstm(True)
4746

4847
@check_opset_after_tf_version("1.15", 8, "might need Scan")
49-
@skip_tf2()
5048
def test_test_single_dynamic_lstm_state_is_not_tuple(self):
5149
self.internal_test_single_dynamic_lstm(False)
5250

@@ -74,11 +72,9 @@ def func(x):
7472
feed_dict = {"input_1:0": x_val}
7573

7674
output_names_with_port = ["output:0", "cell_state:0"]
77-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
78-
graph_validator=lambda g: check_lstm_count(g, 1))
75+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
7976

8077
@check_opset_after_tf_version("1.15", 8, "might need Scan")
81-
@skip_tf2()
8278
def test_single_dynamic_lstm_time_major(self):
8379
units = 5
8480
seq_len = 6
@@ -104,11 +100,9 @@ def func(x):
104100
feed_dict = {"input_1:0": x_val}
105101

106102
output_names_with_port = ["output:0", "cell_state:0"]
107-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
108-
graph_validator=lambda g: check_lstm_count(g, 1))
103+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
109104

110105
@check_opset_after_tf_version("1.15", 8, "might need Scan")
111-
@skip_tf2()
112106
def test_single_dynamic_lstm_forget_bias(self):
113107
units = 5
114108
seq_len = 6
@@ -135,11 +129,9 @@ def func(x):
135129
feed_dict = {"input_1:0": x_val}
136130

137131
output_names_with_port = ["output:0", "cell_state:0"]
138-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
139-
graph_validator=lambda g: check_lstm_count(g, 1))
132+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
140133

141134
@check_opset_after_tf_version("1.15", 8, "might need Select")
142-
@skip_tf2()
143135
def test_single_dynamic_lstm_seq_length_is_const(self):
144136
units = 5
145137
batch_size = 6
@@ -165,11 +157,9 @@ def func(x):
165157
feed_dict = {"input_1:0": x_val}
166158
input_names_with_port = ["input_1:0"]
167159
output_names_with_port = ["output:0", "cell_state:0"]
168-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
169-
graph_validator=lambda g: check_lstm_count(g, 1))
160+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
170161

171162
@check_opset_after_tf_version("1.15", 8, "might need Select")
172-
@skip_tf2()
173163
def test_single_dynamic_lstm_seq_length_is_not_const(self):
174164
for np_dtype in [np.int32, np.int64, np.float32]:
175165
units = 5
@@ -197,11 +187,9 @@ def func(x, seq_length):
197187
feed_dict = {"input_1:0": x_val, "input_2:0": y_val}
198188
input_names_with_port = ["input_1:0", "input_2:0"]
199189
output_names_with_port = ["output:0", "cell_state:0"]
200-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
201-
graph_validator=lambda g: check_lstm_count(g, 1))
190+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
202191

203192
@check_opset_after_tf_version("1.15", 8, "might need Scan")
204-
@skip_tf2()
205193
def test_single_dynamic_lstm_placeholder_input(self):
206194
units = 5
207195
x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]], dtype=np.float32)
@@ -225,11 +213,9 @@ def func(x):
225213
feed_dict = {"input_1:0": x_val}
226214
input_names_with_port = ["input_1:0"]
227215
output_names_with_port = ["output:0", "cell_state:0"]
228-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
229-
graph_validator=lambda g: check_lstm_count(g, 1))
216+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
230217

231218
@check_opset_after_tf_version("1.15", 8, "might need Scan")
232-
@skip_tf2()
233219
def test_single_dynamic_lstm_ch_zero_state_initializer(self):
234220
units = 5
235221
batch_size = 6
@@ -258,11 +244,9 @@ def func(x):
258244
feed_dict = {"input_1:0": x_val}
259245
input_names_with_port = ["input_1:0"]
260246
output_names_with_port = ["output:0", "cell_state:0"]
261-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
262-
graph_validator=lambda g: check_lstm_count(g, 1))
247+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
263248

264249
@check_opset_after_tf_version("1.15", 8, "might need Scan")
265-
@skip_tf2()
266250
def test_single_dynamic_lstm_consume_one_of_ch_tuple(self):
267251
units = 5
268252
batch_size = 6
@@ -288,19 +272,17 @@ def func(x):
288272
feed_dict = {"input_1:0": x_val}
289273
input_names_with_port = ["input_1:0"]
290274
output_names_with_port = ["output:0", "cell_state_c:0"]
291-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
292-
graph_validator=lambda g: check_lstm_count(g, 1))
275+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
293276

294277
@check_opset_after_tf_version("1.15", 8, "might need Scan")
295-
@skip_tf2()
296278
def test_single_dynamic_lstm_random_weights(self, state_is_tuple=True):
297279
hidden_size = 5
298280
batch_size = 6
299281
x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]], dtype=np.float32)
300282
x_val = np.stack([x_val] * batch_size)
301283

302284
def func(x):
303-
initializer = tf.random_uniform_initializer(-1.0, 1.0)
285+
initializer = tf.random_uniform_initializer(-1.0, 1.0, seed=42)
304286

305287
# no scope
306288
cell = LSTMCell(
@@ -318,19 +300,17 @@ def func(x):
318300
feed_dict = {"input_1:0": x_val}
319301
input_names_with_port = ["input_1:0"]
320302
output_names_with_port = ["output:0", "cell_state:0"]
321-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=0.0001,
322-
graph_validator=lambda g: check_lstm_count(g, 1))
303+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=0.0001)
323304

324305
@check_opset_after_tf_version("1.15", 8, "might need Select")
325-
@skip_tf2()
326306
def test_single_dynamic_lstm_random_weights2(self, state_is_tuple=True):
327307
hidden_size = 128
328308
batch_size = 1
329309
x_val = np.random.randn(1, 133).astype('f')
330310
x_val = np.stack([x_val] * batch_size)
331311

332312
def func(x):
333-
initializer = tf.random_uniform_initializer(0.0, 1.0)
313+
initializer = tf.random_uniform_initializer(0.0, 1.0, seed=42)
334314
# no scope
335315
cell = LSTMCell(
336316
hidden_size,
@@ -347,15 +327,12 @@ def func(x):
347327
feed_dict = {"input_1:0": x_val}
348328
input_names_with_port = ["input_1:0"]
349329
output_names_with_port = ["output:0", "cell_state:0"]
350-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=0.01,
351-
graph_validator=lambda g: check_lstm_count(g, 1))
330+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=0.01)
352331

353332
@check_opset_after_tf_version("1.15", 8, "might need Select")
354-
@skip_tf2()
355333
def test_multiple_dynamic_lstm_state_is_tuple(self):
356334
self.internal_test_multiple_dynamic_lstm_with_parameters(True)
357335

358-
@skip_tf2()
359336
@check_opset_after_tf_version("1.15", 8, "might need Scan")
360337
def test_multiple_dynamic_lstm_state_is_not_tuple(self):
361338
self.internal_test_multiple_dynamic_lstm_with_parameters(False)
@@ -406,7 +383,7 @@ def func(x):
406383
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
407384

408385
@check_opset_after_tf_version("1.15", 8, "might need Scan")
409-
@skip_tf2()
386+
@skip_tf2() # Still failing likely due to inconsistent random number initialization
410387
def test_dynamic_basiclstm(self):
411388
units = 5
412389
batch_size = 6
@@ -426,20 +403,20 @@ def func(x):
426403
feed_dict = {"input_1:0": x_val}
427404
input_names_with_port = ["input_1:0"]
428405
output_names_with_port = ["output:0", "cell_state:0"]
429-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=0.0001, atol=1e-06,
430-
graph_validator=lambda g: check_lstm_count(g, 1))
406+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=0.0001, atol=1e-06)
431407

432408
@check_opset_after_tf_version("1.15", 8, "might need Scan")
433-
@skip_tf2()
434409
def test_dynamic_lstm_output_consumed_only(self):
435410
units = 5
436411
batch_size = 6
437412
x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
438413
x_val = np.stack([x_val] * batch_size)
439414

440415
def func(x):
416+
initializer = tf.random_uniform_initializer(0.0, 1.0, seed=42)
441417
cell1 = LSTMCell(
442418
units,
419+
initializer=initializer,
443420
state_is_tuple=True)
444421

445422
outputs, _ = dynamic_rnn(
@@ -452,35 +429,31 @@ def func(x):
452429
feed_dict = {"input_1:0": x_val}
453430
input_names_with_port = ["input_1:0"]
454431
output_names_with_port = ["output:0"]
455-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=0.0001, atol=1e-07,
456-
graph_validator=lambda g: check_lstm_count(g, 1))
432+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=0.0001, atol=1e-07)
457433

458434
@check_opset_after_tf_version("1.15", 8, "might need Scan")
459-
@skip_tf2()
460435
def test_dynamic_lstm_state_consumed_only(self):
461436
units = 5
462437
batch_size = 6
463438
x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
464439
x_val = np.stack([x_val] * batch_size)
465440

466441
def func(x):
467-
cell1 = LSTMCell(units, state_is_tuple=True)
442+
initializer = tf.random_uniform_initializer(0.0, 1.0, seed=42)
443+
cell1 = LSTMCell(units, initializer=initializer, state_is_tuple=True)
468444
_, cell_state = dynamic_rnn(cell1, x, dtype=tf.float32)
469445
return tf.identity(cell_state, name="cell_state")
470446

471447
feed_dict = {"input_1:0": x_val}
472448
input_names_with_port = ["input_1:0"]
473449
output_names_with_port = ["cell_state:0"]
474-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=0.0001,
475-
graph_validator=lambda g: check_lstm_count(g, 1))
450+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=0.0001)
476451

477452
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
478-
@skip_tf2()
479453
def test_dynamic_bilstm_state_is_tuple(self):
480454
self.internal_test_dynamic_bilstm_with_parameters(True)
481455

482456
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
483-
@skip_tf2()
484457
def test_dynamic_bilstm_state_is_not_tuple(self):
485458
self.internal_test_dynamic_bilstm_with_parameters(False)
486459

@@ -513,11 +486,9 @@ def func(x):
513486
feed_dict = {"input_1:0": x_val}
514487
input_names_with_port = ["input_1:0"]
515488
output_names_with_port = ["output:0", "cell_state:0"]
516-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
517-
graph_validator=lambda g: check_lstm_count(g, 1))
489+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
518490

519491
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
520-
@skip_tf2()
521492
def test_dynamic_bilstm_output_consumed_only(self, state_is_tuple=True):
522493
units = 5
523494
batch_size = 6
@@ -547,11 +518,9 @@ def func(x):
547518
feed_dict = {"input_1:0": x_val}
548519
input_names_with_port = ["input_1:0"]
549520
output_names_with_port = ["output:0"]
550-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
551-
graph_validator=lambda g: check_lstm_count(g, 1))
521+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
552522

553523
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
554-
@skip_tf2()
555524
def test_dynamic_bilstm_state_consumed_only(self, state_is_tuple=True):
556525
units = 5
557526
batch_size = 6
@@ -581,11 +550,9 @@ def func(x):
581550
feed_dict = {"input_1:0": x_val}
582551
input_names_with_port = ["input_1:0"]
583552
output_names_with_port = ["cell_state:0"]
584-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
585-
graph_validator=lambda g: check_lstm_count(g, 1))
553+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
586554

587555
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
588-
@skip_tf2()
589556
def test_dynamic_bilstm_outputs_partially_consumed(self, state_is_tuple=True):
590557
units = 5
591558
batch_size = 6
@@ -615,11 +582,9 @@ def func(x):
615582
feed_dict = {"input_1:0": x_val}
616583
input_names_with_port = ["input_1:0"]
617584
output_names_with_port = ["output:0", "cell_state:0"]
618-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
619-
graph_validator=lambda g: check_lstm_count(g, 1))
585+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
620586

621587
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
622-
@skip_tf2()
623588
def test_dynamic_bilstm_unknown_batch_size(self, state_is_tuple=True):
624589
units = 5
625590
batch_size = 6
@@ -649,20 +614,23 @@ def func(x):
649614
feed_dict = {"input_1:0": x_val}
650615
input_names_with_port = ["input_1:0"]
651616
output_names_with_port = ["cell_state:0"]
652-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
653-
graph_validator=lambda g: check_lstm_count(g, 1))
617+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
654618

655619
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
656-
@skip_tf2()
620+
@skip_tf_versions("2.1", "Bug in TF 2.1")
657621
def test_dynamic_multi_bilstm_with_same_input_hidden_size(self):
658622
batch_size = 10
659623
x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
660624
x_val = np.stack([x_val] * batch_size)
661625

662626
def func(x):
627+
initializer1 = tf.random_uniform_initializer(0.0, 1.0, seed=42)
628+
initializer2 = tf.random_uniform_initializer(0.0, 1.0, seed=43)
629+
initializer3 = tf.random_uniform_initializer(0.0, 1.0, seed=44)
630+
initializer4 = tf.random_uniform_initializer(0.0, 1.0, seed=45)
663631
units = 5
664-
cell1 = LSTMCell(units, name="cell1")
665-
cell2 = LSTMCell(units, name="cell2")
632+
cell1 = LSTMCell(units, name="cell1", initializer=initializer1)
633+
cell2 = LSTMCell(units, name="cell2", initializer=initializer2)
666634
outputs_1, cell_state_1 = bidirectional_dynamic_rnn(
667635
cell1,
668636
cell2,
@@ -672,8 +640,8 @@ def func(x):
672640
)
673641

674642
units = 10
675-
cell3 = LSTMCell(units, name="cell3")
676-
cell4 = LSTMCell(units, name="cell4")
643+
cell3 = LSTMCell(units, name="cell3", initializer=initializer3)
644+
cell4 = LSTMCell(units, name="cell4", initializer=initializer4)
677645
outputs_2, cell_state_2 = bidirectional_dynamic_rnn(
678646
cell3,
679647
cell4,
@@ -691,10 +659,9 @@ def func(x):
691659
input_names_with_port = ["input_1:0"]
692660
output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
693661
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
694-
# graph_validator=lambda g: check_lstm_count(g, 2))
695662

696663
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
697-
@skip_tf2()
664+
@skip_tf_versions("2.1", "Bug in TF 2.1")
698665
def test_dynamic_multi_bilstm_with_same_input_seq_len(self):
699666
units = 5
700667
batch_size = 10
@@ -703,9 +670,11 @@ def test_dynamic_multi_bilstm_with_same_input_seq_len(self):
703670
seq_len_val = np.array([3], dtype=np.int32)
704671

705672
def func(x, y1, y2):
673+
initializer1 = tf.random_uniform_initializer(0.0, 1.0, seed=42)
674+
initializer2 = tf.random_uniform_initializer(0.0, 1.0, seed=43)
706675
seq_len1 = tf.tile(y1, [batch_size])
707-
cell1 = LSTMCell(units)
708-
cell2 = LSTMCell(units)
676+
cell1 = LSTMCell(units, initializer=initializer1)
677+
cell2 = LSTMCell(units, initializer=initializer2)
709678
outputs_1, cell_state_1 = bidirectional_dynamic_rnn(
710679
cell1,
711680
cell2,
@@ -714,10 +683,11 @@ def func(x, y1, y2):
714683
dtype=tf.float32,
715684
scope="bilstm_1"
716685
)
717-
686+
initializer1 = tf.random_uniform_initializer(0.0, 1.0, seed=44)
687+
initializer2 = tf.random_uniform_initializer(0.0, 1.0, seed=45)
718688
seq_len2 = tf.tile(y2, [batch_size])
719-
cell1 = LSTMCell(units)
720-
cell2 = LSTMCell(units)
689+
cell1 = LSTMCell(units, initializer=initializer1)
690+
cell2 = LSTMCell(units, initializer=initializer2)
721691
outputs_2, cell_state_2 = bidirectional_dynamic_rnn(
722692
cell1,
723693
cell2,
@@ -736,7 +706,6 @@ def func(x, y1, y2):
736706
input_names_with_port = ["input_1:0", "input_2:0", "input_3:0"]
737707
output_names_with_port = ["output_1:0", "cell_state_1:0", "output_2:0", "cell_state_2:0"]
738708
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
739-
# graph_validator=lambda g: check_lstm_count(g, 2))
740709

741710

742711
if __name__ == '__main__':

0 commit comments

Comments
 (0)