14
14
from tensorflow .python .ops import init_ops
15
15
from tensorflow .python .ops import variable_scope
16
16
from backend_test_base import Tf2OnnxBackendTestBase
17
- from common import unittest_main , check_lstm_count , check_opset_after_tf_version , skip_tf2
17
+ from common import unittest_main , check_opset_after_tf_version , skip_tf2 , skip_tf_versions
18
18
19
19
from tf2onnx .tf_loader import is_tf2
20
20
41
41
42
42
class LSTMTests (Tf2OnnxBackendTestBase ):
43
43
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
44
- @skip_tf2 ()
45
44
def test_test_single_dynamic_lstm_state_is_tuple (self ):
46
45
self .internal_test_single_dynamic_lstm (True )
47
46
48
47
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
49
- @skip_tf2 ()
50
48
def test_test_single_dynamic_lstm_state_is_not_tuple (self ):
51
49
self .internal_test_single_dynamic_lstm (False )
52
50
@@ -74,11 +72,9 @@ def func(x):
74
72
feed_dict = {"input_1:0" : x_val }
75
73
76
74
output_names_with_port = ["output:0" , "cell_state:0" ]
77
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 ,
78
- graph_validator = lambda g : check_lstm_count (g , 1 ))
75
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
79
76
80
77
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
81
- @skip_tf2 ()
82
78
def test_single_dynamic_lstm_time_major (self ):
83
79
units = 5
84
80
seq_len = 6
@@ -104,11 +100,9 @@ def func(x):
104
100
feed_dict = {"input_1:0" : x_val }
105
101
106
102
output_names_with_port = ["output:0" , "cell_state:0" ]
107
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 ,
108
- graph_validator = lambda g : check_lstm_count (g , 1 ))
103
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
109
104
110
105
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
111
- @skip_tf2 ()
112
106
def test_single_dynamic_lstm_forget_bias (self ):
113
107
units = 5
114
108
seq_len = 6
@@ -135,11 +129,9 @@ def func(x):
135
129
feed_dict = {"input_1:0" : x_val }
136
130
137
131
output_names_with_port = ["output:0" , "cell_state:0" ]
138
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 ,
139
- graph_validator = lambda g : check_lstm_count (g , 1 ))
132
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
140
133
141
134
@check_opset_after_tf_version ("1.15" , 8 , "might need Select" )
142
- @skip_tf2 ()
143
135
def test_single_dynamic_lstm_seq_length_is_const (self ):
144
136
units = 5
145
137
batch_size = 6
@@ -165,11 +157,9 @@ def func(x):
165
157
feed_dict = {"input_1:0" : x_val }
166
158
input_names_with_port = ["input_1:0" ]
167
159
output_names_with_port = ["output:0" , "cell_state:0" ]
168
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 ,
169
- graph_validator = lambda g : check_lstm_count (g , 1 ))
160
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
170
161
171
162
@check_opset_after_tf_version ("1.15" , 8 , "might need Select" )
172
- @skip_tf2 ()
173
163
def test_single_dynamic_lstm_seq_length_is_not_const (self ):
174
164
for np_dtype in [np .int32 , np .int64 , np .float32 ]:
175
165
units = 5
@@ -197,11 +187,9 @@ def func(x, seq_length):
197
187
feed_dict = {"input_1:0" : x_val , "input_2:0" : y_val }
198
188
input_names_with_port = ["input_1:0" , "input_2:0" ]
199
189
output_names_with_port = ["output:0" , "cell_state:0" ]
200
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 ,
201
- graph_validator = lambda g : check_lstm_count (g , 1 ))
190
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
202
191
203
192
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
204
- @skip_tf2 ()
205
193
def test_single_dynamic_lstm_placeholder_input (self ):
206
194
units = 5
207
195
x_val = np .array ([[1. , 1. ], [2. , 2. ], [3. , 3. ], [4. , 4. ]], dtype = np .float32 )
@@ -225,11 +213,9 @@ def func(x):
225
213
feed_dict = {"input_1:0" : x_val }
226
214
input_names_with_port = ["input_1:0" ]
227
215
output_names_with_port = ["output:0" , "cell_state:0" ]
228
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 ,
229
- graph_validator = lambda g : check_lstm_count (g , 1 ))
216
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
230
217
231
218
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
232
- @skip_tf2 ()
233
219
def test_single_dynamic_lstm_ch_zero_state_initializer (self ):
234
220
units = 5
235
221
batch_size = 6
@@ -258,11 +244,9 @@ def func(x):
258
244
feed_dict = {"input_1:0" : x_val }
259
245
input_names_with_port = ["input_1:0" ]
260
246
output_names_with_port = ["output:0" , "cell_state:0" ]
261
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 ,
262
- graph_validator = lambda g : check_lstm_count (g , 1 ))
247
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
263
248
264
249
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
265
- @skip_tf2 ()
266
250
def test_single_dynamic_lstm_consume_one_of_ch_tuple (self ):
267
251
units = 5
268
252
batch_size = 6
@@ -288,19 +272,17 @@ def func(x):
288
272
feed_dict = {"input_1:0" : x_val }
289
273
input_names_with_port = ["input_1:0" ]
290
274
output_names_with_port = ["output:0" , "cell_state_c:0" ]
291
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 ,
292
- graph_validator = lambda g : check_lstm_count (g , 1 ))
275
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
293
276
294
277
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
295
- @skip_tf2 ()
296
278
def test_single_dynamic_lstm_random_weights (self , state_is_tuple = True ):
297
279
hidden_size = 5
298
280
batch_size = 6
299
281
x_val = np .array ([[1. , 1. ], [2. , 2. ], [3. , 3. ], [4. , 4. ]], dtype = np .float32 )
300
282
x_val = np .stack ([x_val ] * batch_size )
301
283
302
284
def func (x ):
303
- initializer = tf .random_uniform_initializer (- 1.0 , 1.0 )
285
+ initializer = tf .random_uniform_initializer (- 1.0 , 1.0 , seed = 42 )
304
286
305
287
# no scope
306
288
cell = LSTMCell (
@@ -318,19 +300,17 @@ def func(x):
318
300
feed_dict = {"input_1:0" : x_val }
319
301
input_names_with_port = ["input_1:0" ]
320
302
output_names_with_port = ["output:0" , "cell_state:0" ]
321
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 0.0001 ,
322
- graph_validator = lambda g : check_lstm_count (g , 1 ))
303
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 0.0001 )
323
304
324
305
@check_opset_after_tf_version ("1.15" , 8 , "might need Select" )
325
- @skip_tf2 ()
326
306
def test_single_dynamic_lstm_random_weights2 (self , state_is_tuple = True ):
327
307
hidden_size = 128
328
308
batch_size = 1
329
309
x_val = np .random .randn (1 , 133 ).astype ('f' )
330
310
x_val = np .stack ([x_val ] * batch_size )
331
311
332
312
def func (x ):
333
- initializer = tf .random_uniform_initializer (0.0 , 1.0 )
313
+ initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 42 )
334
314
# no scope
335
315
cell = LSTMCell (
336
316
hidden_size ,
@@ -347,15 +327,12 @@ def func(x):
347
327
feed_dict = {"input_1:0" : x_val }
348
328
input_names_with_port = ["input_1:0" ]
349
329
output_names_with_port = ["output:0" , "cell_state:0" ]
350
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 0.01 ,
351
- graph_validator = lambda g : check_lstm_count (g , 1 ))
330
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 0.01 )
352
331
353
332
@check_opset_after_tf_version ("1.15" , 8 , "might need Select" )
354
- @skip_tf2 ()
355
333
def test_multiple_dynamic_lstm_state_is_tuple (self ):
356
334
self .internal_test_multiple_dynamic_lstm_with_parameters (True )
357
335
358
- @skip_tf2 ()
359
336
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
360
337
def test_multiple_dynamic_lstm_state_is_not_tuple (self ):
361
338
self .internal_test_multiple_dynamic_lstm_with_parameters (False )
@@ -406,7 +383,7 @@ def func(x):
406
383
self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
407
384
408
385
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
409
- @skip_tf2 ()
386
+ @skip_tf2 () # Still failing likely due to inconsistent random number initialization
410
387
def test_dynamic_basiclstm (self ):
411
388
units = 5
412
389
batch_size = 6
@@ -426,20 +403,20 @@ def func(x):
426
403
feed_dict = {"input_1:0" : x_val }
427
404
input_names_with_port = ["input_1:0" ]
428
405
output_names_with_port = ["output:0" , "cell_state:0" ]
429
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 0.0001 , atol = 1e-06 ,
430
- graph_validator = lambda g : check_lstm_count (g , 1 ))
406
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 0.0001 , atol = 1e-06 )
431
407
432
408
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
433
- @skip_tf2 ()
434
409
def test_dynamic_lstm_output_consumed_only (self ):
435
410
units = 5
436
411
batch_size = 6
437
412
x_val = np .array ([[1. , 1. ], [2. , 2. ], [3. , 3. ]], dtype = np .float32 )
438
413
x_val = np .stack ([x_val ] * batch_size )
439
414
440
415
def func (x ):
416
+ initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 42 )
441
417
cell1 = LSTMCell (
442
418
units ,
419
+ initializer = initializer ,
443
420
state_is_tuple = True )
444
421
445
422
outputs , _ = dynamic_rnn (
@@ -452,35 +429,31 @@ def func(x):
452
429
feed_dict = {"input_1:0" : x_val }
453
430
input_names_with_port = ["input_1:0" ]
454
431
output_names_with_port = ["output:0" ]
455
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 0.0001 , atol = 1e-07 ,
456
- graph_validator = lambda g : check_lstm_count (g , 1 ))
432
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 0.0001 , atol = 1e-07 )
457
433
458
434
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
459
- @skip_tf2 ()
460
435
def test_dynamic_lstm_state_consumed_only (self ):
461
436
units = 5
462
437
batch_size = 6
463
438
x_val = np .array ([[1. , 1. ], [2. , 2. ], [3. , 3. ]], dtype = np .float32 )
464
439
x_val = np .stack ([x_val ] * batch_size )
465
440
466
441
def func (x ):
467
- cell1 = LSTMCell (units , state_is_tuple = True )
442
+ initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 42 )
443
+ cell1 = LSTMCell (units , initializer = initializer , state_is_tuple = True )
468
444
_ , cell_state = dynamic_rnn (cell1 , x , dtype = tf .float32 )
469
445
return tf .identity (cell_state , name = "cell_state" )
470
446
471
447
feed_dict = {"input_1:0" : x_val }
472
448
input_names_with_port = ["input_1:0" ]
473
449
output_names_with_port = ["cell_state:0" ]
474
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 0.0001 ,
475
- graph_validator = lambda g : check_lstm_count (g , 1 ))
450
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 0.0001 )
476
451
477
452
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
478
- @skip_tf2 ()
479
453
def test_dynamic_bilstm_state_is_tuple (self ):
480
454
self .internal_test_dynamic_bilstm_with_parameters (True )
481
455
482
456
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
483
- @skip_tf2 ()
484
457
def test_dynamic_bilstm_state_is_not_tuple (self ):
485
458
self .internal_test_dynamic_bilstm_with_parameters (False )
486
459
@@ -513,11 +486,9 @@ def func(x):
513
486
feed_dict = {"input_1:0" : x_val }
514
487
input_names_with_port = ["input_1:0" ]
515
488
output_names_with_port = ["output:0" , "cell_state:0" ]
516
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 ,
517
- graph_validator = lambda g : check_lstm_count (g , 1 ))
489
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
518
490
519
491
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
520
- @skip_tf2 ()
521
492
def test_dynamic_bilstm_output_consumed_only (self , state_is_tuple = True ):
522
493
units = 5
523
494
batch_size = 6
@@ -547,11 +518,9 @@ def func(x):
547
518
feed_dict = {"input_1:0" : x_val }
548
519
input_names_with_port = ["input_1:0" ]
549
520
output_names_with_port = ["output:0" ]
550
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 ,
551
- graph_validator = lambda g : check_lstm_count (g , 1 ))
521
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
552
522
553
523
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
554
- @skip_tf2 ()
555
524
def test_dynamic_bilstm_state_consumed_only (self , state_is_tuple = True ):
556
525
units = 5
557
526
batch_size = 6
@@ -581,11 +550,9 @@ def func(x):
581
550
feed_dict = {"input_1:0" : x_val }
582
551
input_names_with_port = ["input_1:0" ]
583
552
output_names_with_port = ["cell_state:0" ]
584
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 ,
585
- graph_validator = lambda g : check_lstm_count (g , 1 ))
553
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
586
554
587
555
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
588
- @skip_tf2 ()
589
556
def test_dynamic_bilstm_outputs_partially_consumed (self , state_is_tuple = True ):
590
557
units = 5
591
558
batch_size = 6
@@ -615,11 +582,9 @@ def func(x):
615
582
feed_dict = {"input_1:0" : x_val }
616
583
input_names_with_port = ["input_1:0" ]
617
584
output_names_with_port = ["output:0" , "cell_state:0" ]
618
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 ,
619
- graph_validator = lambda g : check_lstm_count (g , 1 ))
585
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
620
586
621
587
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
622
- @skip_tf2 ()
623
588
def test_dynamic_bilstm_unknown_batch_size (self , state_is_tuple = True ):
624
589
units = 5
625
590
batch_size = 6
@@ -649,20 +614,23 @@ def func(x):
649
614
feed_dict = {"input_1:0" : x_val }
650
615
input_names_with_port = ["input_1:0" ]
651
616
output_names_with_port = ["cell_state:0" ]
652
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 ,
653
- graph_validator = lambda g : check_lstm_count (g , 1 ))
617
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-06 )
654
618
655
619
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
656
- @skip_tf2 ( )
620
+ @skip_tf_versions ( "2.1" , "Bug in TF 2.1" )
657
621
def test_dynamic_multi_bilstm_with_same_input_hidden_size (self ):
658
622
batch_size = 10
659
623
x_val = np .array ([[1. , 1. ], [2. , 2. ], [3. , 3. ]], dtype = np .float32 )
660
624
x_val = np .stack ([x_val ] * batch_size )
661
625
662
626
def func (x ):
627
+ initializer1 = tf .random_uniform_initializer (0.0 , 1.0 , seed = 42 )
628
+ initializer2 = tf .random_uniform_initializer (0.0 , 1.0 , seed = 43 )
629
+ initializer3 = tf .random_uniform_initializer (0.0 , 1.0 , seed = 44 )
630
+ initializer4 = tf .random_uniform_initializer (0.0 , 1.0 , seed = 45 )
663
631
units = 5
664
- cell1 = LSTMCell (units , name = "cell1" )
665
- cell2 = LSTMCell (units , name = "cell2" )
632
+ cell1 = LSTMCell (units , name = "cell1" , initializer = initializer1 )
633
+ cell2 = LSTMCell (units , name = "cell2" , initializer = initializer2 )
666
634
outputs_1 , cell_state_1 = bidirectional_dynamic_rnn (
667
635
cell1 ,
668
636
cell2 ,
@@ -672,8 +640,8 @@ def func(x):
672
640
)
673
641
674
642
units = 10
675
- cell3 = LSTMCell (units , name = "cell3" )
676
- cell4 = LSTMCell (units , name = "cell4" )
643
+ cell3 = LSTMCell (units , name = "cell3" , initializer = initializer3 )
644
+ cell4 = LSTMCell (units , name = "cell4" , initializer = initializer4 )
677
645
outputs_2 , cell_state_2 = bidirectional_dynamic_rnn (
678
646
cell3 ,
679
647
cell4 ,
@@ -691,10 +659,9 @@ def func(x):
691
659
input_names_with_port = ["input_1:0" ]
692
660
output_names_with_port = ["output_1:0" , "cell_state_1:0" , "output_2:0" , "cell_state_2:0" ]
693
661
self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 )
694
- # graph_validator=lambda g: check_lstm_count(g, 2))
695
662
696
663
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
697
- @skip_tf2 ( )
664
+ @skip_tf_versions ( "2.1" , "Bug in TF 2.1" )
698
665
def test_dynamic_multi_bilstm_with_same_input_seq_len (self ):
699
666
units = 5
700
667
batch_size = 10
@@ -703,9 +670,11 @@ def test_dynamic_multi_bilstm_with_same_input_seq_len(self):
703
670
seq_len_val = np .array ([3 ], dtype = np .int32 )
704
671
705
672
def func (x , y1 , y2 ):
673
+ initializer1 = tf .random_uniform_initializer (0.0 , 1.0 , seed = 42 )
674
+ initializer2 = tf .random_uniform_initializer (0.0 , 1.0 , seed = 43 )
706
675
seq_len1 = tf .tile (y1 , [batch_size ])
707
- cell1 = LSTMCell (units )
708
- cell2 = LSTMCell (units )
676
+ cell1 = LSTMCell (units , initializer = initializer1 )
677
+ cell2 = LSTMCell (units , initializer = initializer2 )
709
678
outputs_1 , cell_state_1 = bidirectional_dynamic_rnn (
710
679
cell1 ,
711
680
cell2 ,
@@ -714,10 +683,11 @@ def func(x, y1, y2):
714
683
dtype = tf .float32 ,
715
684
scope = "bilstm_1"
716
685
)
717
-
686
+ initializer1 = tf .random_uniform_initializer (0.0 , 1.0 , seed = 44 )
687
+ initializer2 = tf .random_uniform_initializer (0.0 , 1.0 , seed = 45 )
718
688
seq_len2 = tf .tile (y2 , [batch_size ])
719
- cell1 = LSTMCell (units )
720
- cell2 = LSTMCell (units )
689
+ cell1 = LSTMCell (units , initializer = initializer1 )
690
+ cell2 = LSTMCell (units , initializer = initializer2 )
721
691
outputs_2 , cell_state_2 = bidirectional_dynamic_rnn (
722
692
cell1 ,
723
693
cell2 ,
@@ -736,7 +706,6 @@ def func(x, y1, y2):
736
706
input_names_with_port = ["input_1:0" , "input_2:0" , "input_3:0" ]
737
707
output_names_with_port = ["output_1:0" , "cell_state_1:0" , "output_2:0" , "cell_state_2:0" ]
738
708
self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 )
739
- # graph_validator=lambda g: check_lstm_count(g, 2))
740
709
741
710
742
711
if __name__ == '__main__' :
0 commit comments