Skip to content

Commit b05a993

Browse files
Implement GRU rewriter for tf2 (#1688)
Signed-off-by: Tom Wildenhain <[email protected]>
Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent 04f9e86 commit b05a993

File tree

9 files changed

+277
-67
lines changed

9 files changed

+277
-67
lines changed

tests/test_gru.py

Lines changed: 81 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from tensorflow.python.ops import init_ops
1010
from tensorflow.python.ops import variable_scope
1111
from backend_test_base import Tf2OnnxBackendTestBase
12-
from common import unittest_main, check_gru_count, check_opset_after_tf_version, skip_tf2
12+
from common import unittest_main, check_gru_count, check_opset_after_tf_version, check_op_count
1313
from tf2onnx.tf_loader import is_tf2
1414

1515
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,cell-var-from-loop
@@ -35,17 +35,25 @@
3535
# TODO: as a workaround, set batch_size to 1 for now to bypass an onnxruntime bug; revert it when the bug is fixed
3636
class GRUTests(Tf2OnnxBackendTestBase):
3737

38-
def run_test_case(self, *args, **kwargs): #pylint: disable=arguments-differ
38+
def run_test_case(self, *args, graph_validator=None, **kwargs): #pylint: disable=arguments-differ
3939
# TF GRU has an unknown dim
4040
tmp = self.config.allow_missing_shapes
4141
self.config.allow_missing_shapes = True
42+
def new_graph_validator(g):
43+
good = True
44+
if graph_validator is not None:
45+
good = good and graph_validator(g)
46+
if is_tf2() and ':' in g.outputs[0]:
47+
# Only check for tf2 and tfjs, not tflite
48+
good = good and check_op_count(g, "Loop", 0, disabled=False)
49+
good = good and check_op_count(g, "Scan", 0, disabled=False)
50+
return good
4251
try:
43-
super().run_test_case(*args, **kwargs)
52+
super().run_test_case(*args, graph_validator=new_graph_validator, **kwargs)
4453
finally:
4554
self.config.allow_missing_shapes = tmp
4655

4756
@check_opset_after_tf_version("1.15", 8, "might need Scan")
48-
@skip_tf2()
4957
def test_single_dynamic_gru(self):
5058
units = 5
5159
batch_size = 1
@@ -56,7 +64,9 @@ def func(x):
5664
# no scope
5765
cell = GRUCell(
5866
units,
59-
activation=None)
67+
activation=None,
68+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
69+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43))
6070
outputs, cell_state = dynamic_rnn(
6171
cell,
6272
x,
@@ -71,7 +81,6 @@ def func(x):
7181
graph_validator=lambda g: check_gru_count(g, 1))
7282

7383
@check_opset_after_tf_version("1.15", 8, "might need Scan")
74-
@skip_tf2()
7584
def test_multiple_dynamic_gru(self):
7685
units = 5
7786
batch_size = 1
@@ -84,7 +93,9 @@ def func(x):
8493
# no scope
8594
cell = GRUCell(
8695
units,
87-
activation=None)
96+
activation=None,
97+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
98+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43))
8899
outputs, cell_state = dynamic_rnn(
89100
cell,
90101
x,
@@ -95,7 +106,9 @@ def func(x):
95106
# given scope
96107
cell = GRUCell(
97108
units,
98-
activation=None)
109+
activation=None,
110+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
111+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=45))
99112
with variable_scope.variable_scope("root1") as scope:
100113
outputs, cell_state = dynamic_rnn(
101114
cell,
@@ -115,7 +128,6 @@ def func(x):
115128
# graph_validator=lambda g: check_gru_count(g, 2))
116129

117130
@check_opset_after_tf_version("1.15", 8, "might need Select")
118-
@skip_tf2()
119131
def test_single_dynamic_gru_seq_length_is_const(self):
120132
units = 5
121133
batch_size = 1
@@ -143,7 +155,6 @@ def func(x):
143155
graph_validator=lambda g: check_gru_count(g, 1))
144156

145157
@check_opset_after_tf_version("1.15", 8, "might need Select")
146-
@skip_tf2()
147158
def test_single_dynamic_gru_seq_length_is_not_const(self):
148159
for np_dtype in [np.int32, np.int64, np.float32]:
149160
units = 5
@@ -174,7 +185,6 @@ def func(x, seq_length):
174185
graph_validator=lambda g: check_gru_count(g, 1))
175186

176187
@check_opset_after_tf_version("1.15", 8, "might need Scan")
177-
@skip_tf2()
178188
def test_single_dynamic_gru_placeholder_input(self):
179189
units = 5
180190
x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]], dtype=np.float32)
@@ -200,7 +210,6 @@ def func(x):
200210
graph_validator=lambda g: check_gru_count(g, 1))
201211

202212
@check_opset_after_tf_version("1.15", 8, "might need Scan")
203-
@skip_tf2()
204213
def test_single_dynamic_gru_ch_zero_state_initializer(self):
205214
units = 5
206215
batch_size = 1
@@ -231,15 +240,14 @@ def func(x):
231240
graph_validator=lambda g: check_gru_count(g, 1))
232241

233242
@check_opset_after_tf_version("1.15", 8, "might need Scan")
234-
@skip_tf2()
235243
def test_single_dynamic_gru_random_weights(self):
236244
hidden_size = 5
237245
batch_size = 1
238246
x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]], dtype=np.float32)
239247
x_val = np.stack([x_val] * batch_size)
240248

241249
def func(x):
242-
initializer = tf.random_uniform_initializer(-1.0, 1.0)
250+
initializer = tf.random_uniform_initializer(-1.0, 1.0, seed=42)
243251

244252
# no scope
245253
cell = GRUCell(
@@ -260,15 +268,16 @@ def func(x):
260268
graph_validator=lambda g: check_gru_count(g, 1))
261269

262270
@check_opset_after_tf_version("1.15", 8, "might need Scan")
263-
@skip_tf2()
264271
def test_single_dynamic_gru_random_weights2(self):
265272
hidden_size = 128
266273
batch_size = 1
267274
x_val = np.random.randn(1, 133).astype('f')
268275
x_val = np.stack([x_val] * batch_size)
269276

277+
270278
def func(x):
271-
initializer = tf.random_uniform_initializer(0.0, 1.0)
279+
#initializer = tf.constant_initializer(5.0)
280+
initializer = tf.random_uniform_initializer(0.0, 1.0, seed=42)
272281
# no scope
273282
cell = GRUCell(
274283
hidden_size,
@@ -288,15 +297,14 @@ def func(x):
288297
graph_validator=lambda g: check_gru_count(g, 1))
289298

290299
@check_opset_after_tf_version("1.15", 8, "might need Scan")
291-
@skip_tf2()
292300
def test_dynamic_gru_output_consumed_only(self):
293301
units = 5
294302
batch_size = 6
295303
x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
296304
x_val = np.stack([x_val] * batch_size)
297305

298306
def func(x):
299-
initializer = tf.random_uniform_initializer(-1.0, 1.0)
307+
initializer = tf.random_uniform_initializer(-1.0, 1.0, seed=42)
300308
cell1 = GRUCell(
301309
units,
302310
kernel_initializer=initializer)
@@ -315,15 +323,14 @@ def func(x):
315323
graph_validator=lambda g: check_gru_count(g, 1))
316324

317325
@check_opset_after_tf_version("1.15", 8, "might need Scan")
318-
@skip_tf2()
319326
def test_dynamic_gru_state_consumed_only(self):
320327
units = 5
321328
batch_size = 6
322329
x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
323330
x_val = np.stack([x_val] * batch_size)
324331

325332
def func(x):
326-
initializer = tf.random_uniform_initializer(-1.0, 1.0)
333+
initializer = tf.random_uniform_initializer(-1.0, 1.0, seed=42)
327334
cell1 = GRUCell(
328335
units,
329336
kernel_initializer=initializer)
@@ -342,7 +349,6 @@ def func(x):
342349
graph_validator=lambda g: check_gru_count(g, 1))
343350

344351
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
345-
@skip_tf2()
346352
def test_dynamic_bigru(self):
347353
units = 5
348354
batch_size = 1
@@ -374,7 +380,6 @@ def func(x):
374380
graph_validator=lambda g: check_gru_count(g, 1))
375381

376382
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
377-
@skip_tf2()
378383
def test_dynamic_bigru_output_consumed_only(self):
379384
units = 5
380385
batch_size = 1
@@ -406,7 +411,6 @@ def func(x):
406411
graph_validator=lambda g: check_gru_count(g, 1))
407412

408413
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
409-
@skip_tf2()
410414
def test_dynamic_bigru_state_consumed_only(self):
411415
units = 5
412416
batch_size = 1
@@ -438,7 +442,6 @@ def func(x):
438442
graph_validator=lambda g: check_gru_count(g, 1))
439443

440444
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
441-
@skip_tf2()
442445
def test_dynamic_bidirectional_but_one_gru(self):
443446
units = 5
444447
batch_size = 1
@@ -467,7 +470,6 @@ def func(x):
467470
graph_validator=lambda g: check_gru_count(g, 1))
468471

469472
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
470-
@skip_tf2()
471473
def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
472474
units = 5
473475
batch_size = 1
@@ -478,7 +480,9 @@ def func(x):
478480

479481
# bigru, no scope
480482
cell = GRUCell(
481-
units)
483+
units,
484+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
485+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43))
482486
outputs, _ = bidirectional_dynamic_rnn(
483487
cell,
484488
cell,
@@ -494,7 +498,6 @@ def func(x):
494498
graph_validator=lambda g: check_gru_count(g, 1))
495499

496500
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
497-
@skip_tf2()
498501
def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
499502
units = 5
500503
batch_size = 1
@@ -505,7 +508,9 @@ def func(x):
505508

506509
# bigru, no scope
507510
cell = GRUCell(
508-
units)
511+
units,
512+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
513+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43))
509514
_, cell_state = bidirectional_dynamic_rnn(
510515
cell,
511516
cell,
@@ -521,7 +526,6 @@ def func(x):
521526
graph_validator=lambda g: check_gru_count(g, 1))
522527

523528
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
524-
@skip_tf2()
525529
def test_dynamic_bigru_unknown_batch_size(self):
526530
units = 5
527531
batch_size = 6
@@ -530,8 +534,14 @@ def test_dynamic_bigru_unknown_batch_size(self):
530534

531535
def func(x):
532536

533-
cell1 = GRUCell(units)
534-
cell2 = GRUCell(units)
537+
cell1 = GRUCell(
538+
units,
539+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
540+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43))
541+
cell2 = GRUCell(
542+
units,
543+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
544+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=45))
535545
_, cell_state = bidirectional_dynamic_rnn(
536546
cell1,
537547
cell2,
@@ -548,7 +558,6 @@ def func(x):
548558
graph_validator=lambda g: check_gru_count(g, 1))
549559

550560
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
551-
@skip_tf2()
552561
def test_dynamic_bigru_outputs_partially_consumed(self):
553562
units = 5
554563
batch_size = 6
@@ -557,8 +566,14 @@ def test_dynamic_bigru_outputs_partially_consumed(self):
557566

558567
def func(x):
559568

560-
cell1 = GRUCell(units)
561-
cell2 = GRUCell(units)
569+
cell1 = GRUCell(
570+
units,
571+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
572+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43))
573+
cell2 = GRUCell(
574+
units,
575+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
576+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=45))
562577
(output_fw, _), (_, state_bw) = bidirectional_dynamic_rnn(
563578
cell1,
564579
cell2,
@@ -574,7 +589,6 @@ def func(x):
574589
graph_validator=lambda g: check_gru_count(g, 1))
575590

576591
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
577-
@skip_tf2()
578592
def test_dynamic_multi_bigru_with_same_input_hidden_size(self):
579593
batch_size = 10
580594
x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
@@ -583,8 +597,14 @@ def test_dynamic_multi_bigru_with_same_input_hidden_size(self):
583597
def func(x):
584598
# bigru, no scope
585599
units = 5
586-
cell1 = GRUCell(units)
587-
cell2 = GRUCell(units)
600+
cell1 = GRUCell(
601+
units,
602+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
603+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43))
604+
cell2 = GRUCell(
605+
units,
606+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
607+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=45))
588608
outputs_1, cell_state_1 = bidirectional_dynamic_rnn(
589609
cell1,
590610
cell2,
@@ -594,8 +614,14 @@ def func(x):
594614
)
595615

596616
units = 10
597-
cell1 = GRUCell(units)
598-
cell2 = GRUCell(units)
617+
cell1 = GRUCell(
618+
units,
619+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
620+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43))
621+
cell2 = GRUCell(
622+
units,
623+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
624+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=45))
599625
outputs_2, cell_state_2 = bidirectional_dynamic_rnn(
600626
cell1,
601627
cell2,
@@ -616,7 +642,6 @@ def func(x):
616642
# graph_validator=lambda g: check_gru_count(g, 2))
617643

618644
@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
619-
@skip_tf2()
620645
def test_dynamic_multi_bigru_with_same_input_seq_len(self):
621646
units = 5
622647
batch_size = 10
@@ -626,8 +651,14 @@ def test_dynamic_multi_bigru_with_same_input_seq_len(self):
626651

627652
def func(x, y1, y2):
628653
seq_len1 = tf.tile(y1, [batch_size])
629-
cell1 = GRUCell(units)
630-
cell2 = GRUCell(units)
654+
cell1 = GRUCell(
655+
units,
656+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
657+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43))
658+
cell2 = GRUCell(
659+
units,
660+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
661+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=45))
631662
outputs_1, cell_state_1 = bidirectional_dynamic_rnn(
632663
cell1,
633664
cell2,
@@ -637,8 +668,14 @@ def func(x, y1, y2):
637668
scope="bigru_1"
638669
)
639670
seq_len2 = tf.tile(y2, [batch_size])
640-
cell1 = GRUCell(units)
641-
cell2 = GRUCell(units)
671+
cell1 = GRUCell(
672+
units,
673+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=46),
674+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=47))
675+
cell2 = GRUCell(
676+
units,
677+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=48),
678+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=49))
642679
outputs_2, cell_state_2 = bidirectional_dynamic_rnn(
643680
cell1,
644681
cell2,

0 commit comments

Comments
 (0)