16
16
from tensorflow .python .ops import init_ops
17
17
from tensorflow .python .ops import variable_scope
18
18
from backend_test_base import Tf2OnnxBackendTestBase
19
- from common import unittest_main
19
+ from common import unittest_main , check_gru_count
20
20
21
21
22
22
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -47,7 +47,8 @@ def test_single_dynamic_gru(self):
47
47
input_names_with_port = ["input_1:0" ]
48
48
feed_dict = {"input_1:0" : x_val }
49
49
output_names_with_port = ["output:0" , "cell_state:0" ]
50
- self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-03 , atol = 1e-06 )
50
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-03 , atol = 1e-06 ,
51
+ graph_validator = lambda g : check_gru_count (g , 1 ))
51
52
52
53
def test_multiple_dynamic_gru (self ):
53
54
units = 5
@@ -93,7 +94,8 @@ def test_multiple_dynamic_gru(self):
93
94
feed_dict = {"input_1:0" : x_val }
94
95
input_names_with_port = ["input_1:0" ]
95
96
output_names_with_port = ["output:0" , "cell_state:0" ]
96
- self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 )
97
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 ,
98
+ graph_validator = lambda g : check_gru_count (g , 2 ))
97
99
98
100
def test_single_dynamic_gru_seq_length_is_const (self ):
99
101
units = 5
@@ -119,7 +121,8 @@ def test_single_dynamic_gru_seq_length_is_const(self):
119
121
feed_dict = {"input_1:0" : x_val }
120
122
input_names_with_port = ["input_1:0" ]
121
123
output_names_with_port = ["output:0" , "cell_state:0" ]
122
- self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 )
124
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 ,
125
+ graph_validator = lambda g : check_gru_count (g , 1 ))
123
126
124
127
def test_single_dynamic_gru_seq_length_is_not_const (self ):
125
128
units = 5
@@ -148,7 +151,8 @@ def test_single_dynamic_gru_seq_length_is_not_const(self):
148
151
feed_dict = {"input_1:0" : x_val , "input_2:0" : y_val }
149
152
input_names_with_port = ["input_1:0" , "input_2:0" ]
150
153
output_names_with_port = ["output:0" , "cell_state:0" ]
151
- self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-03 , atol = 1e-06 )
154
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-03 , atol = 1e-06 ,
155
+ graph_validator = lambda g : check_gru_count (g , 1 ))
152
156
153
157
def test_single_dynamic_gru_placeholder_input (self ):
154
158
units = 5
@@ -172,7 +176,8 @@ def test_single_dynamic_gru_placeholder_input(self):
172
176
feed_dict = {"input_1:0" : x_val }
173
177
input_names_with_port = ["input_1:0" ]
174
178
output_names_with_port = ["output:0" , "cell_state:0" ]
175
- self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-03 , atol = 1e-06 )
179
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-03 , atol = 1e-06 ,
180
+ graph_validator = lambda g : check_gru_count (g , 1 ))
176
181
177
182
def test_single_dynamic_gru_ch_zero_state_initializer (self ):
178
183
units = 5
@@ -201,7 +206,8 @@ def test_single_dynamic_gru_ch_zero_state_initializer(self):
201
206
feed_dict = {"input_1:0" : x_val }
202
207
input_names_with_port = ["input_1:0" ]
203
208
output_names_with_port = ["output:0" , "cell_state:0" ]
204
- self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-03 , atol = 1e-06 )
209
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-03 , atol = 1e-06 ,
210
+ graph_validator = lambda g : check_gru_count (g , 1 ))
205
211
206
212
@unittest .skip ("FIXME: disable for now for accuracy problem" )
207
213
def test_single_dynamic_gru_random_weights (self ):
@@ -229,7 +235,8 @@ def test_single_dynamic_gru_random_weights(self):
229
235
feed_dict = {"input_1:0" : x_val }
230
236
input_names_with_port = ["input_1:0" ]
231
237
output_names_with_port = ["output:0" , "cell_state:0" ]
232
- self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , 0.0001 )
238
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , 0.0001 ,
239
+ graph_validator = lambda g : check_gru_count (g , 1 ))
233
240
234
241
@unittest .skip ("FIXME: disable for now for accuracy problem" )
235
242
def test_single_dynamic_gru_random_weights2 (self ):
@@ -256,7 +263,8 @@ def test_single_dynamic_gru_random_weights2(self):
256
263
feed_dict = {"input_1:0" : x_val }
257
264
input_names_with_port = ["input_1:0" ]
258
265
output_names_with_port = ["output:0" , "cell_state:0" ]
259
- self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , 0.01 )
266
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , 0.01 ,
267
+ graph_validator = lambda g : check_gru_count (g , 1 ))
260
268
261
269
def test_dynamic_gru_output_consumed_only (self ):
262
270
units = 5
@@ -280,7 +288,8 @@ def test_dynamic_gru_output_consumed_only(self):
280
288
feed_dict = {"input_1:0" : x_val }
281
289
input_names_with_port = ["input_1:0" ]
282
290
output_names_with_port = ["output:0" ]
283
- self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , 0.0001 )
291
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , 0.0001 ,
292
+ graph_validator = lambda g : check_gru_count (g , 1 ))
284
293
285
294
def test_dynamic_gru_state_consumed_only (self ):
286
295
units = 5
@@ -304,7 +313,8 @@ def test_dynamic_gru_state_consumed_only(self):
304
313
feed_dict = {"input_1:0" : x_val }
305
314
input_names_with_port = ["input_1:0" ]
306
315
output_names_with_port = ["cell_state:0" ]
307
- self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 0.0001 , atol = 1e-06 )
316
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 0.0001 , atol = 1e-06 ,
317
+ graph_validator = lambda g : check_gru_count (g , 1 ))
308
318
309
319
def test_dynamic_bigru (self ):
310
320
units = 5
@@ -335,7 +345,8 @@ def test_dynamic_bigru(self):
335
345
feed_dict = {"input_1:0" : x_val }
336
346
input_names_with_port = ["input_1:0" ]
337
347
output_names_with_port = ["output:0" , "cell_state:0" ]
338
- self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 )
348
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 ,
349
+ graph_validator = lambda g : check_gru_count (g , 1 ))
339
350
340
351
def test_dynamic_bigru_output_consumed_only (self ):
341
352
units = 5
@@ -365,7 +376,8 @@ def test_dynamic_bigru_output_consumed_only(self):
365
376
feed_dict = {"input_1:0" : x_val }
366
377
input_names_with_port = ["input_1:0" ]
367
378
output_names_with_port = ["output:0" ]
368
- self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 )
379
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 ,
380
+ graph_validator = lambda g : check_gru_count (g , 1 ))
369
381
370
382
def test_dynamic_bigru_state_consumed_only (self ):
371
383
units = 5
@@ -395,7 +407,8 @@ def test_dynamic_bigru_state_consumed_only(self):
395
407
feed_dict = {"input_1:0" : x_val }
396
408
input_names_with_port = ["input_1:0" ]
397
409
output_names_with_port = ["cell_state:0" ]
398
- self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 )
410
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 ,
411
+ graph_validator = lambda g : check_gru_count (g , 1 ))
399
412
400
413
def test_dynamic_bidirectional_but_one_gru (self ):
401
414
units = 5
@@ -423,7 +436,8 @@ def test_dynamic_bidirectional_but_one_gru(self):
423
436
feed_dict = {"input_1:0" : x_val }
424
437
input_names_with_port = ["input_1:0" ]
425
438
output_names_with_port = ["output:0" , "cell_state:0" ]
426
- self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 )
439
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 ,
440
+ graph_validator = lambda g : check_gru_count (g , 1 ))
427
441
428
442
def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only (self ):
429
443
units = 5
@@ -448,7 +462,8 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
448
462
feed_dict = {"input_1:0" : x_val }
449
463
input_names_with_port = ["input_1:0" ]
450
464
output_names_with_port = ["output:0" ]
451
- self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 )
465
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 ,
466
+ graph_validator = lambda g : check_gru_count (g , 1 ))
452
467
453
468
def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only (self ):
454
469
units = 5
@@ -473,7 +488,8 @@ def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
473
488
feed_dict = {"input_1:0" : x_val }
474
489
input_names_with_port = ["input_1:0" ]
475
490
output_names_with_port = ["cell_state:0" ]
476
- self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 )
491
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 ,
492
+ graph_validator = lambda g : check_gru_count (g , 1 ))
477
493
478
494
479
495
if __name__ == '__main__' :
0 commit comments