@@ -255,6 +255,54 @@ def test_single_dynamic_gru_random_weights2(self):
255
255
output_names_with_port = ["output:0" , "cell_state:0" ]
256
256
self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , 0.01 )
257
257
258
+ def test_dynamic_gru_output_consumed_only (self ):
259
+ units = 5
260
+ batch_size = 6
261
+ x_val = np .array ([[1. , 1. ], [2. , 2. ], [3. , 3. ]], dtype = np .float32 )
262
+ x_val = np .stack ([x_val ] * batch_size )
263
+
264
+ x = tf .placeholder (tf .float32 , x_val .shape , name = "input_1" )
265
+ initializer = tf .random_uniform_initializer (- 1.0 , 1.0 )
266
+ cell1 = rnn .GRUCell (
267
+ units ,
268
+ kernel_initializer = initializer )
269
+
270
+ outputs , _ = tf .nn .dynamic_rnn (
271
+ cell1 ,
272
+ x ,
273
+ dtype = tf .float32 )
274
+
275
+ _ = tf .identity (outputs , name = "output" )
276
+
277
+ feed_dict = {"input_1:0" : x_val }
278
+ input_names_with_port = ["input_1:0" ]
279
+ output_names_with_port = ["output:0" ]
280
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , 0.0001 )
281
+
282
+ def test_dynamic_gru_state_consumed_only (self ):
283
+ units = 5
284
+ batch_size = 6
285
+ x_val = np .array ([[1. , 1. ], [2. , 2. ], [3. , 3. ]], dtype = np .float32 )
286
+ x_val = np .stack ([x_val ] * batch_size )
287
+
288
+ x = tf .placeholder (tf .float32 , x_val .shape , name = "input_1" )
289
+ initializer = tf .random_uniform_initializer (- 1.0 , 1.0 )
290
+ cell1 = rnn .GRUCell (
291
+ units ,
292
+ kernel_initializer = initializer )
293
+
294
+ _ , cell_state = tf .nn .dynamic_rnn (
295
+ cell1 ,
296
+ x ,
297
+ dtype = tf .float32 )
298
+
299
+ _ = tf .identity (cell_state , name = "cell_state" )
300
+
301
+ feed_dict = {"input_1:0" : x_val }
302
+ input_names_with_port = ["input_1:0" ]
303
+ output_names_with_port = ["cell_state:0" ]
304
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , 0.0001 )
305
+
258
306
def test_dynamic_bigru (self ):
259
307
units = 5
260
308
batch_size = 1
@@ -264,7 +312,6 @@ def test_dynamic_bigru(self):
264
312
x = tf .placeholder (tf .float32 , x_val .shape , name = "input_1" )
265
313
initializer = init_ops .constant_initializer (0.5 )
266
314
267
- gru_list = []
268
315
if True :
269
316
# bigru, no scope
270
317
cell1 = rnn .GRUCell (
@@ -278,7 +325,6 @@ def test_dynamic_bigru(self):
278
325
cell2 ,
279
326
x ,
280
327
dtype = tf .float32 )
281
- gru_list .append (outputs )
282
328
283
329
_ = tf .identity (outputs , name = "output" )
284
330
_ = tf .identity (cell_state , name = "cell_state" )
@@ -297,7 +343,6 @@ def test_dynamic_bigru_output_consumed_only(self):
297
343
x = tf .placeholder (tf .float32 , x_val .shape , name = "input_1" )
298
344
initializer = init_ops .constant_initializer (0.5 )
299
345
300
- gru_list = []
301
346
if True :
302
347
# bigru, no scope
303
348
cell1 = rnn .GRUCell (
@@ -311,7 +356,6 @@ def test_dynamic_bigru_output_consumed_only(self):
311
356
cell2 ,
312
357
x ,
313
358
dtype = tf .float32 )
314
- gru_list .append (outputs )
315
359
316
360
_ = tf .identity (outputs , name = "output" )
317
361
@@ -320,6 +364,36 @@ def test_dynamic_bigru_output_consumed_only(self):
320
364
output_names_with_port = ["output:0" ]
321
365
self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 )
322
366
367
+ def test_dynamic_bigru_state_consumed_only (self ):
368
+ units = 5
369
+ batch_size = 1
370
+ x_val = np .array ([[1. , 1. ], [2. , 2. ], [3. , 3. ]], dtype = np .float32 )
371
+ x_val = np .stack ([x_val ] * batch_size )
372
+
373
+ x = tf .placeholder (tf .float32 , x_val .shape , name = "input_1" )
374
+ initializer = init_ops .constant_initializer (0.5 )
375
+
376
+ if True :
377
+ # bigru, no scope
378
+ cell1 = rnn .GRUCell (
379
+ units ,
380
+ kernel_initializer = initializer )
381
+ cell2 = rnn .GRUCell (
382
+ units ,
383
+ kernel_initializer = initializer )
384
+ _ , cell_state = tf .nn .bidirectional_dynamic_rnn (
385
+ cell1 ,
386
+ cell2 ,
387
+ x ,
388
+ dtype = tf .float32 )
389
+
390
+ _ = tf .identity (cell_state , name = "cell_state" )
391
+
392
+ feed_dict = {"input_1:0" : x_val }
393
+ input_names_with_port = ["input_1:0" ]
394
+ output_names_with_port = ["cell_state:0" ]
395
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 )
396
+
323
397
def test_dynamic_bidirectional_but_one_gru (self ):
324
398
units = 5
325
399
batch_size = 1
@@ -329,7 +403,6 @@ def test_dynamic_bidirectional_but_one_gru(self):
329
403
x = tf .placeholder (tf .float32 , x_val .shape , name = "input_1" )
330
404
initializer = init_ops .constant_initializer (0.5 )
331
405
332
- gru_list = []
333
406
if True :
334
407
# bigru, no scope
335
408
cell = rnn .GRUCell (
@@ -340,7 +413,6 @@ def test_dynamic_bidirectional_but_one_gru(self):
340
413
cell ,
341
414
x ,
342
415
dtype = tf .float32 )
343
- gru_list .append (outputs )
344
416
345
417
_ = tf .identity (outputs , name = "output" )
346
418
_ = tf .identity (cell_state , name = "cell_state" )
@@ -358,7 +430,6 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
358
430
359
431
x = tf .placeholder (tf .float32 , x_val .shape , name = "input_1" )
360
432
361
- gru_list = []
362
433
if True :
363
434
# bigru, no scope
364
435
cell = rnn .GRUCell (
@@ -368,7 +439,6 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
368
439
cell ,
369
440
x ,
370
441
dtype = tf .float32 )
371
- gru_list .append (outputs )
372
442
373
443
_ = tf .identity (outputs , name = "output" )
374
444
@@ -377,6 +447,31 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
377
447
output_names_with_port = ["output:0" ]
378
448
self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 )
379
449
450
+ def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only (self ):
451
+ units = 5
452
+ batch_size = 1
453
+ x_val = np .array ([[1. , 1. ], [2. , 2. ], [3. , 3. ]], dtype = np .float32 )
454
+ x_val = np .stack ([x_val ] * batch_size )
455
+
456
+ x = tf .placeholder (tf .float32 , x_val .shape , name = "input_1" )
457
+
458
+ if True :
459
+ # bigru, no scope
460
+ cell = rnn .GRUCell (
461
+ units )
462
+ _ , cell_state = tf .nn .bidirectional_dynamic_rnn (
463
+ cell ,
464
+ cell ,
465
+ x ,
466
+ dtype = tf .float32 )
467
+
468
+ _ = tf .identity (cell_state , name = "cell_state" )
469
+
470
+ feed_dict = {"input_1:0" : x_val }
471
+ input_names_with_port = ["input_1:0" ]
472
+ output_names_with_port = ["cell_state:0" ]
473
+ self .run_test_case (feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 )
474
+
380
475
381
476
if __name__ == '__main__' :
382
477
Tf2OnnxBackendTestBase .trigger (GRUTests )
0 commit comments