Skip to content

Commit 9d0ee88

Browse files
authored
Merge pull request #306 from zhijxu-MS/zhijxu/model_converter_support
fix bug when only "state" of gru is consumed;
2 parents 78a09b5 + 3a9264b commit 9d0ee88

File tree

5 files changed

+287
-24
lines changed

5 files changed

+287
-24
lines changed

tests/run_pretrained_models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
import tensorflow as tf
2424
from tensorflow.core.framework import graph_pb2
2525
from tensorflow.python.framework.graph_util import convert_variables_to_constants
26+
# contrib ops are registered only when the module is imported, so the following import statement is needed;
27+
# otherwise a TF runtime error will show up when the TF model is restored from the pb file because of unregistered ops.
28+
import tensorflow.contrib.rnn # pylint: disable=unused-import
2629
import yaml
2730
import PIL.Image
2831

tests/test_gru.py

Lines changed: 103 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,54 @@ def test_single_dynamic_gru_random_weights2(self):
255255
output_names_with_port = ["output:0", "cell_state:0"]
256256
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.01)
257257

258+
def test_dynamic_gru_output_consumed_only(self):
259+
units = 5
260+
batch_size = 6
261+
x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
262+
x_val = np.stack([x_val] * batch_size)
263+
264+
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
265+
initializer = tf.random_uniform_initializer(-1.0, 1.0)
266+
cell1 = rnn.GRUCell(
267+
units,
268+
kernel_initializer=initializer)
269+
270+
outputs, _ = tf.nn.dynamic_rnn(
271+
cell1,
272+
x,
273+
dtype=tf.float32)
274+
275+
_ = tf.identity(outputs, name="output")
276+
277+
feed_dict = {"input_1:0": x_val}
278+
input_names_with_port = ["input_1:0"]
279+
output_names_with_port = ["output:0"]
280+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.0001)
281+
282+
def test_dynamic_gru_state_consumed_only(self):
283+
units = 5
284+
batch_size = 6
285+
x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
286+
x_val = np.stack([x_val] * batch_size)
287+
288+
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
289+
initializer = tf.random_uniform_initializer(-1.0, 1.0)
290+
cell1 = rnn.GRUCell(
291+
units,
292+
kernel_initializer=initializer)
293+
294+
_, cell_state = tf.nn.dynamic_rnn(
295+
cell1,
296+
x,
297+
dtype=tf.float32)
298+
299+
_ = tf.identity(cell_state, name="cell_state")
300+
301+
feed_dict = {"input_1:0": x_val}
302+
input_names_with_port = ["input_1:0"]
303+
output_names_with_port = ["cell_state:0"]
304+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.0001)
305+
258306
def test_dynamic_bigru(self):
259307
units = 5
260308
batch_size = 1
@@ -264,7 +312,6 @@ def test_dynamic_bigru(self):
264312
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
265313
initializer = init_ops.constant_initializer(0.5)
266314

267-
gru_list = []
268315
if True:
269316
# bigru, no scope
270317
cell1 = rnn.GRUCell(
@@ -278,7 +325,6 @@ def test_dynamic_bigru(self):
278325
cell2,
279326
x,
280327
dtype=tf.float32)
281-
gru_list.append(outputs)
282328

283329
_ = tf.identity(outputs, name="output")
284330
_ = tf.identity(cell_state, name="cell_state")
@@ -297,7 +343,6 @@ def test_dynamic_bigru_output_consumed_only(self):
297343
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
298344
initializer = init_ops.constant_initializer(0.5)
299345

300-
gru_list = []
301346
if True:
302347
# bigru, no scope
303348
cell1 = rnn.GRUCell(
@@ -311,7 +356,6 @@ def test_dynamic_bigru_output_consumed_only(self):
311356
cell2,
312357
x,
313358
dtype=tf.float32)
314-
gru_list.append(outputs)
315359

316360
_ = tf.identity(outputs, name="output")
317361

@@ -320,6 +364,36 @@ def test_dynamic_bigru_output_consumed_only(self):
320364
output_names_with_port = ["output:0"]
321365
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
322366

367+
def test_dynamic_bigru_state_consumed_only(self):
368+
units = 5
369+
batch_size = 1
370+
x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
371+
x_val = np.stack([x_val] * batch_size)
372+
373+
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
374+
initializer = init_ops.constant_initializer(0.5)
375+
376+
if True:
377+
# bigru, no scope
378+
cell1 = rnn.GRUCell(
379+
units,
380+
kernel_initializer=initializer)
381+
cell2 = rnn.GRUCell(
382+
units,
383+
kernel_initializer=initializer)
384+
_, cell_state = tf.nn.bidirectional_dynamic_rnn(
385+
cell1,
386+
cell2,
387+
x,
388+
dtype=tf.float32)
389+
390+
_ = tf.identity(cell_state, name="cell_state")
391+
392+
feed_dict = {"input_1:0": x_val}
393+
input_names_with_port = ["input_1:0"]
394+
output_names_with_port = ["cell_state:0"]
395+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
396+
323397
def test_dynamic_bidirectional_but_one_gru(self):
324398
units = 5
325399
batch_size = 1
@@ -329,7 +403,6 @@ def test_dynamic_bidirectional_but_one_gru(self):
329403
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
330404
initializer = init_ops.constant_initializer(0.5)
331405

332-
gru_list = []
333406
if True:
334407
# bigru, no scope
335408
cell = rnn.GRUCell(
@@ -340,7 +413,6 @@ def test_dynamic_bidirectional_but_one_gru(self):
340413
cell,
341414
x,
342415
dtype=tf.float32)
343-
gru_list.append(outputs)
344416

345417
_ = tf.identity(outputs, name="output")
346418
_ = tf.identity(cell_state, name="cell_state")
@@ -358,7 +430,6 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
358430

359431
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
360432

361-
gru_list = []
362433
if True:
363434
# bigru, no scope
364435
cell = rnn.GRUCell(
@@ -368,7 +439,6 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
368439
cell,
369440
x,
370441
dtype=tf.float32)
371-
gru_list.append(outputs)
372442

373443
_ = tf.identity(outputs, name="output")
374444

@@ -377,6 +447,31 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
377447
output_names_with_port = ["output:0"]
378448
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
379449

450+
def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
451+
units = 5
452+
batch_size = 1
453+
x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
454+
x_val = np.stack([x_val] * batch_size)
455+
456+
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
457+
458+
if True:
459+
# bigru, no scope
460+
cell = rnn.GRUCell(
461+
units)
462+
_, cell_state = tf.nn.bidirectional_dynamic_rnn(
463+
cell,
464+
cell,
465+
x,
466+
dtype=tf.float32)
467+
468+
_ = tf.identity(cell_state, name="cell_state")
469+
470+
feed_dict = {"input_1:0": x_val}
471+
input_names_with_port = ["input_1:0"]
472+
output_names_with_port = ["cell_state:0"]
473+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
474+
380475

381476
if __name__ == '__main__':
382477
Tf2OnnxBackendTestBase.trigger(GRUTests)

tests/test_grublock.py

Lines changed: 96 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,50 @@ def test_single_dynamic_gru_random_weights2(self):
238238
output_names_with_port = ["output:0", "cell_state:0"]
239239
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.01)
240240

241+
def test_dynamic_gru_output_consumed_only(self):
242+
units = 5
243+
batch_size = 6
244+
x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
245+
x_val = np.stack([x_val] * batch_size)
246+
247+
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
248+
cell1 = rnn.GRUBlockCell(
249+
units)
250+
251+
outputs, _ = tf.nn.dynamic_rnn(
252+
cell1,
253+
x,
254+
dtype=tf.float32)
255+
256+
_ = tf.identity(outputs, name="output")
257+
258+
feed_dict = {"input_1:0": x_val}
259+
input_names_with_port = ["input_1:0"]
260+
output_names_with_port = ["output:0"]
261+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.0001)
262+
263+
def test_dynamic_gru_state_consumed_only(self):
264+
units = 5
265+
batch_size = 6
266+
x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
267+
x_val = np.stack([x_val] * batch_size)
268+
269+
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
270+
cell1 = rnn.GRUBlockCell(
271+
units)
272+
273+
_, cell_state = tf.nn.dynamic_rnn(
274+
cell1,
275+
x,
276+
dtype=tf.float32)
277+
278+
_ = tf.identity(cell_state, name="cell_state")
279+
280+
feed_dict = {"input_1:0": x_val}
281+
input_names_with_port = ["input_1:0"]
282+
output_names_with_port = ["cell_state:0"]
283+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.0001)
284+
241285
def test_dynamic_bigru(self):
242286
units = 5
243287
batch_size = 1
@@ -246,7 +290,6 @@ def test_dynamic_bigru(self):
246290

247291
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
248292

249-
gru_list = []
250293
if True:
251294
# bigru, no scope
252295
cell1 = rnn.GRUBlockCell(
@@ -258,7 +301,6 @@ def test_dynamic_bigru(self):
258301
cell2,
259302
x,
260303
dtype=tf.float32)
261-
gru_list.append(outputs)
262304

263305
_ = tf.identity(outputs, name="output")
264306
_ = tf.identity(cell_state, name="cell_state")
@@ -276,7 +318,6 @@ def test_dynamic_bigru_output_consumed_only(self):
276318

277319
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
278320

279-
gru_list = []
280321
if True:
281322
# bigru, no scope
282323
cell1 = rnn.GRUBlockCell(
@@ -288,7 +329,6 @@ def test_dynamic_bigru_output_consumed_only(self):
288329
cell2,
289330
x,
290331
dtype=tf.float32)
291-
gru_list.append(outputs)
292332

293333
_ = tf.identity(outputs, name="output")
294334

@@ -297,6 +337,33 @@ def test_dynamic_bigru_output_consumed_only(self):
297337
output_names_with_port = ["output:0"]
298338
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
299339

340+
def test_dynamic_bigru_state_consumed_only(self):
341+
units = 5
342+
batch_size = 1
343+
x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
344+
x_val = np.stack([x_val] * batch_size)
345+
346+
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
347+
348+
if True:
349+
# bigru, no scope
350+
cell1 = rnn.GRUBlockCell(
351+
units)
352+
cell2 = rnn.GRUBlockCell(
353+
units)
354+
_, cell_state = tf.nn.bidirectional_dynamic_rnn(
355+
cell1,
356+
cell2,
357+
x,
358+
dtype=tf.float32)
359+
360+
_ = tf.identity(cell_state, name="cell_state")
361+
362+
feed_dict = {"input_1:0": x_val}
363+
input_names_with_port = ["input_1:0"]
364+
output_names_with_port = ["cell_state:0"]
365+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
366+
300367
def test_dynamic_bidirectional_but_one_gru(self):
301368
units = 5
302369
batch_size = 1
@@ -305,7 +372,6 @@ def test_dynamic_bidirectional_but_one_gru(self):
305372

306373
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
307374

308-
gru_list = []
309375
if True:
310376
# bigru, no scope
311377
cell = rnn.GRUBlockCell(
@@ -315,7 +381,6 @@ def test_dynamic_bidirectional_but_one_gru(self):
315381
cell,
316382
x,
317383
dtype=tf.float32)
318-
gru_list.append(outputs)
319384

320385
_ = tf.identity(outputs, name="output")
321386
_ = tf.identity(cell_state, name="cell_state")
@@ -333,7 +398,6 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
333398

334399
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
335400

336-
gru_list = []
337401
if True:
338402
# bigru, no scope
339403
cell = rnn.GRUBlockCell(
@@ -343,7 +407,6 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
343407
cell,
344408
x,
345409
dtype=tf.float32)
346-
gru_list.append(outputs)
347410

348411
_ = tf.identity(outputs, name="output")
349412

@@ -352,6 +415,31 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
352415
output_names_with_port = ["output:0"]
353416
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
354417

418+
def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
419+
units = 5
420+
batch_size = 1
421+
x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
422+
x_val = np.stack([x_val] * batch_size)
423+
424+
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
425+
426+
if True:
427+
# bigru, no scope
428+
cell = rnn.GRUBlockCell(
429+
units)
430+
_, cell_state = tf.nn.bidirectional_dynamic_rnn(
431+
cell,
432+
cell,
433+
x,
434+
dtype=tf.float32)
435+
436+
_ = tf.identity(cell_state, name="cell_state")
437+
438+
feed_dict = {"input_1:0": x_val}
439+
input_names_with_port = ["input_1:0"]
440+
output_names_with_port = ["cell_state:0"]
441+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3)
442+
355443

356444
if __name__ == '__main__':
357445
Tf2OnnxBackendTestBase.trigger(GRUBlockTests)

0 commit comments

Comments
 (0)