9
9
from tensorflow .python .ops import init_ops
10
10
from tensorflow .python .ops import variable_scope
11
11
from backend_test_base import Tf2OnnxBackendTestBase
12
- from common import unittest_main , check_gru_count , check_opset_after_tf_version , skip_tf2
12
+ from common import unittest_main , check_gru_count , check_opset_after_tf_version , check_op_count
13
13
from tf2onnx .tf_loader import is_tf2
14
14
15
15
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,cell-var-from-loop
35
35
# TODO: as a workaround, set batch_size to 1 for now to bypass a onnxruntime bug, revert it when the bug is fixed
36
36
class GRUTests (Tf2OnnxBackendTestBase ):
37
37
38
- def run_test_case (self , * args , ** kwargs ): #pylint: disable=arguments-differ
38
+ def run_test_case (self , * args , graph_validator = None , ** kwargs ): #pylint: disable=arguments-differ
39
39
# TF GRU has an unknown dim
40
40
tmp = self .config .allow_missing_shapes
41
41
self .config .allow_missing_shapes = True
42
+ def new_graph_validator (g ):
43
+ good = True
44
+ if graph_validator is not None :
45
+ good = good and graph_validator (g )
46
+ if is_tf2 () and ':' in g .outputs [0 ]:
47
+ # Only check for tf2 and tfjs, not tflite
48
+ good = good and check_op_count (g , "Loop" , 0 , disabled = False )
49
+ good = good and check_op_count (g , "Scan" , 0 , disabled = False )
50
+ return good
42
51
try :
43
- super ().run_test_case (* args , ** kwargs )
52
+ super ().run_test_case (* args , graph_validator = new_graph_validator , ** kwargs )
44
53
finally :
45
54
self .config .allow_missing_shapes = tmp
46
55
47
56
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
48
- @skip_tf2 ()
49
57
def test_single_dynamic_gru (self ):
50
58
units = 5
51
59
batch_size = 1
@@ -56,7 +64,9 @@ def func(x):
56
64
# no scope
57
65
cell = GRUCell (
58
66
units ,
59
- activation = None )
67
+ activation = None ,
68
+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 42 ),
69
+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 43 ))
60
70
outputs , cell_state = dynamic_rnn (
61
71
cell ,
62
72
x ,
@@ -71,7 +81,6 @@ def func(x):
71
81
graph_validator = lambda g : check_gru_count (g , 1 ))
72
82
73
83
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
74
- @skip_tf2 ()
75
84
def test_multiple_dynamic_gru (self ):
76
85
units = 5
77
86
batch_size = 1
@@ -84,7 +93,9 @@ def func(x):
84
93
# no scope
85
94
cell = GRUCell (
86
95
units ,
87
- activation = None )
96
+ activation = None ,
97
+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 42 ),
98
+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 43 ))
88
99
outputs , cell_state = dynamic_rnn (
89
100
cell ,
90
101
x ,
@@ -95,7 +106,9 @@ def func(x):
95
106
# given scope
96
107
cell = GRUCell (
97
108
units ,
98
- activation = None )
109
+ activation = None ,
110
+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 44 ),
111
+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 45 ))
99
112
with variable_scope .variable_scope ("root1" ) as scope :
100
113
outputs , cell_state = dynamic_rnn (
101
114
cell ,
@@ -115,7 +128,6 @@ def func(x):
115
128
# graph_validator=lambda g: check_gru_count(g, 2))
116
129
117
130
@check_opset_after_tf_version ("1.15" , 8 , "might need Select" )
118
- @skip_tf2 ()
119
131
def test_single_dynamic_gru_seq_length_is_const (self ):
120
132
units = 5
121
133
batch_size = 1
@@ -143,7 +155,6 @@ def func(x):
143
155
graph_validator = lambda g : check_gru_count (g , 1 ))
144
156
145
157
@check_opset_after_tf_version ("1.15" , 8 , "might need Select" )
146
- @skip_tf2 ()
147
158
def test_single_dynamic_gru_seq_length_is_not_const (self ):
148
159
for np_dtype in [np .int32 , np .int64 , np .float32 ]:
149
160
units = 5
@@ -174,7 +185,6 @@ def func(x, seq_length):
174
185
graph_validator = lambda g : check_gru_count (g , 1 ))
175
186
176
187
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
177
- @skip_tf2 ()
178
188
def test_single_dynamic_gru_placeholder_input (self ):
179
189
units = 5
180
190
x_val = np .array ([[1. , 1. ], [2. , 2. ], [3. , 3. ], [4. , 4. ]], dtype = np .float32 )
@@ -200,7 +210,6 @@ def func(x):
200
210
graph_validator = lambda g : check_gru_count (g , 1 ))
201
211
202
212
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
203
- @skip_tf2 ()
204
213
def test_single_dynamic_gru_ch_zero_state_initializer (self ):
205
214
units = 5
206
215
batch_size = 1
@@ -231,15 +240,14 @@ def func(x):
231
240
graph_validator = lambda g : check_gru_count (g , 1 ))
232
241
233
242
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
234
- @skip_tf2 ()
235
243
def test_single_dynamic_gru_random_weights (self ):
236
244
hidden_size = 5
237
245
batch_size = 1
238
246
x_val = np .array ([[1. , 1. ], [2. , 2. ], [3. , 3. ], [4. , 4. ]], dtype = np .float32 )
239
247
x_val = np .stack ([x_val ] * batch_size )
240
248
241
249
def func (x ):
242
- initializer = tf .random_uniform_initializer (- 1.0 , 1.0 )
250
+ initializer = tf .random_uniform_initializer (- 1.0 , 1.0 , seed = 42 )
243
251
244
252
# no scope
245
253
cell = GRUCell (
@@ -260,15 +268,16 @@ def func(x):
260
268
graph_validator = lambda g : check_gru_count (g , 1 ))
261
269
262
270
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
263
- @skip_tf2 ()
264
271
def test_single_dynamic_gru_random_weights2 (self ):
265
272
hidden_size = 128
266
273
batch_size = 1
267
274
x_val = np .random .randn (1 , 133 ).astype ('f' )
268
275
x_val = np .stack ([x_val ] * batch_size )
269
276
277
+
270
278
def func (x ):
271
- initializer = tf .random_uniform_initializer (0.0 , 1.0 )
279
+ #initializer = tf.constant_initializer(5.0)
280
+ initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 42 )
272
281
# no scope
273
282
cell = GRUCell (
274
283
hidden_size ,
@@ -288,15 +297,14 @@ def func(x):
288
297
graph_validator = lambda g : check_gru_count (g , 1 ))
289
298
290
299
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
291
- @skip_tf2 ()
292
300
def test_dynamic_gru_output_consumed_only (self ):
293
301
units = 5
294
302
batch_size = 6
295
303
x_val = np .array ([[1. , 1. ], [2. , 2. ], [3. , 3. ]], dtype = np .float32 )
296
304
x_val = np .stack ([x_val ] * batch_size )
297
305
298
306
def func (x ):
299
- initializer = tf .random_uniform_initializer (- 1.0 , 1.0 )
307
+ initializer = tf .random_uniform_initializer (- 1.0 , 1.0 , seed = 42 )
300
308
cell1 = GRUCell (
301
309
units ,
302
310
kernel_initializer = initializer )
@@ -315,15 +323,14 @@ def func(x):
315
323
graph_validator = lambda g : check_gru_count (g , 1 ))
316
324
317
325
@check_opset_after_tf_version ("1.15" , 8 , "might need Scan" )
318
- @skip_tf2 ()
319
326
def test_dynamic_gru_state_consumed_only (self ):
320
327
units = 5
321
328
batch_size = 6
322
329
x_val = np .array ([[1. , 1. ], [2. , 2. ], [3. , 3. ]], dtype = np .float32 )
323
330
x_val = np .stack ([x_val ] * batch_size )
324
331
325
332
def func (x ):
326
- initializer = tf .random_uniform_initializer (- 1.0 , 1.0 )
333
+ initializer = tf .random_uniform_initializer (- 1.0 , 1.0 , seed = 42 )
327
334
cell1 = GRUCell (
328
335
units ,
329
336
kernel_initializer = initializer )
@@ -342,7 +349,6 @@ def func(x):
342
349
graph_validator = lambda g : check_gru_count (g , 1 ))
343
350
344
351
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
345
- @skip_tf2 ()
346
352
def test_dynamic_bigru (self ):
347
353
units = 5
348
354
batch_size = 1
@@ -374,7 +380,6 @@ def func(x):
374
380
graph_validator = lambda g : check_gru_count (g , 1 ))
375
381
376
382
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
377
- @skip_tf2 ()
378
383
def test_dynamic_bigru_output_consumed_only (self ):
379
384
units = 5
380
385
batch_size = 1
@@ -406,7 +411,6 @@ def func(x):
406
411
graph_validator = lambda g : check_gru_count (g , 1 ))
407
412
408
413
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
409
- @skip_tf2 ()
410
414
def test_dynamic_bigru_state_consumed_only (self ):
411
415
units = 5
412
416
batch_size = 1
@@ -438,7 +442,6 @@ def func(x):
438
442
graph_validator = lambda g : check_gru_count (g , 1 ))
439
443
440
444
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
441
- @skip_tf2 ()
442
445
def test_dynamic_bidirectional_but_one_gru (self ):
443
446
units = 5
444
447
batch_size = 1
@@ -467,7 +470,6 @@ def func(x):
467
470
graph_validator = lambda g : check_gru_count (g , 1 ))
468
471
469
472
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
470
- @skip_tf2 ()
471
473
def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only (self ):
472
474
units = 5
473
475
batch_size = 1
@@ -478,7 +480,9 @@ def func(x):
478
480
479
481
# bigru, no scope
480
482
cell = GRUCell (
481
- units )
483
+ units ,
484
+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 42 ),
485
+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 43 ))
482
486
outputs , _ = bidirectional_dynamic_rnn (
483
487
cell ,
484
488
cell ,
@@ -494,7 +498,6 @@ def func(x):
494
498
graph_validator = lambda g : check_gru_count (g , 1 ))
495
499
496
500
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
497
- @skip_tf2 ()
498
501
def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only (self ):
499
502
units = 5
500
503
batch_size = 1
@@ -505,7 +508,9 @@ def func(x):
505
508
506
509
# bigru, no scope
507
510
cell = GRUCell (
508
- units )
511
+ units ,
512
+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 42 ),
513
+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 43 ))
509
514
_ , cell_state = bidirectional_dynamic_rnn (
510
515
cell ,
511
516
cell ,
@@ -521,7 +526,6 @@ def func(x):
521
526
graph_validator = lambda g : check_gru_count (g , 1 ))
522
527
523
528
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
524
- @skip_tf2 ()
525
529
def test_dynamic_bigru_unknown_batch_size (self ):
526
530
units = 5
527
531
batch_size = 6
@@ -530,8 +534,14 @@ def test_dynamic_bigru_unknown_batch_size(self):
530
534
531
535
def func (x ):
532
536
533
- cell1 = GRUCell (units )
534
- cell2 = GRUCell (units )
537
+ cell1 = GRUCell (
538
+ units ,
539
+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 42 ),
540
+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 43 ))
541
+ cell2 = GRUCell (
542
+ units ,
543
+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 44 ),
544
+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 45 ))
535
545
_ , cell_state = bidirectional_dynamic_rnn (
536
546
cell1 ,
537
547
cell2 ,
@@ -548,7 +558,6 @@ def func(x):
548
558
graph_validator = lambda g : check_gru_count (g , 1 ))
549
559
550
560
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
551
- @skip_tf2 ()
552
561
def test_dynamic_bigru_outputs_partially_consumed (self ):
553
562
units = 5
554
563
batch_size = 6
@@ -557,8 +566,14 @@ def test_dynamic_bigru_outputs_partially_consumed(self):
557
566
558
567
def func (x ):
559
568
560
- cell1 = GRUCell (units )
561
- cell2 = GRUCell (units )
569
+ cell1 = GRUCell (
570
+ units ,
571
+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 42 ),
572
+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 43 ))
573
+ cell2 = GRUCell (
574
+ units ,
575
+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 44 ),
576
+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 45 ))
562
577
(output_fw , _ ), (_ , state_bw ) = bidirectional_dynamic_rnn (
563
578
cell1 ,
564
579
cell2 ,
@@ -574,7 +589,6 @@ def func(x):
574
589
graph_validator = lambda g : check_gru_count (g , 1 ))
575
590
576
591
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
577
- @skip_tf2 ()
578
592
def test_dynamic_multi_bigru_with_same_input_hidden_size (self ):
579
593
batch_size = 10
580
594
x_val = np .array ([[1. , 1. ], [2. , 2. ], [3. , 3. ]], dtype = np .float32 )
@@ -583,8 +597,14 @@ def test_dynamic_multi_bigru_with_same_input_hidden_size(self):
583
597
def func (x ):
584
598
# bigru, no scope
585
599
units = 5
586
- cell1 = GRUCell (units )
587
- cell2 = GRUCell (units )
600
+ cell1 = GRUCell (
601
+ units ,
602
+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 42 ),
603
+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 43 ))
604
+ cell2 = GRUCell (
605
+ units ,
606
+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 44 ),
607
+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 45 ))
588
608
outputs_1 , cell_state_1 = bidirectional_dynamic_rnn (
589
609
cell1 ,
590
610
cell2 ,
@@ -594,8 +614,14 @@ def func(x):
594
614
)
595
615
596
616
units = 10
597
- cell1 = GRUCell (units )
598
- cell2 = GRUCell (units )
617
+ cell1 = GRUCell (
618
+ units ,
619
+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 42 ),
620
+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 43 ))
621
+ cell2 = GRUCell (
622
+ units ,
623
+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 44 ),
624
+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 45 ))
599
625
outputs_2 , cell_state_2 = bidirectional_dynamic_rnn (
600
626
cell1 ,
601
627
cell2 ,
@@ -616,7 +642,6 @@ def func(x):
616
642
# graph_validator=lambda g: check_gru_count(g, 2))
617
643
618
644
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
619
- @skip_tf2 ()
620
645
def test_dynamic_multi_bigru_with_same_input_seq_len (self ):
621
646
units = 5
622
647
batch_size = 10
@@ -626,8 +651,14 @@ def test_dynamic_multi_bigru_with_same_input_seq_len(self):
626
651
627
652
def func (x , y1 , y2 ):
628
653
seq_len1 = tf .tile (y1 , [batch_size ])
629
- cell1 = GRUCell (units )
630
- cell2 = GRUCell (units )
654
+ cell1 = GRUCell (
655
+ units ,
656
+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 42 ),
657
+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 43 ))
658
+ cell2 = GRUCell (
659
+ units ,
660
+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 44 ),
661
+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 45 ))
631
662
outputs_1 , cell_state_1 = bidirectional_dynamic_rnn (
632
663
cell1 ,
633
664
cell2 ,
@@ -637,8 +668,14 @@ def func(x, y1, y2):
637
668
scope = "bigru_1"
638
669
)
639
670
seq_len2 = tf .tile (y2 , [batch_size ])
640
- cell1 = GRUCell (units )
641
- cell2 = GRUCell (units )
671
+ cell1 = GRUCell (
672
+ units ,
673
+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 46 ),
674
+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 47 ))
675
+ cell2 = GRUCell (
676
+ units ,
677
+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 48 ),
678
+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 49 ))
642
679
outputs_2 , cell_state_2 = bidirectional_dynamic_rnn (
643
680
cell1 ,
644
681
cell2 ,
0 commit comments