Skip to content

Commit 966a65d

Browse files
authored
Merge pull request #383 from lucienwang1009/rnn_rewriter
refactor lstm_rewriter and support rewriting LSTMBlockCell
2 parents a2b8085 + 79e86e0 commit 966a65d

14 files changed

+1401
-364
lines changed

tests/test_gru.py

Lines changed: 83 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -61,32 +61,30 @@ def test_multiple_dynamic_gru(self):
6161

6262
gru_output_list = []
6363
gru_cell_state_list = []
64-
if True:
65-
# no scope
66-
cell = rnn.GRUCell(
67-
units,
68-
activation=None)
64+
# no scope
65+
cell = rnn.GRUCell(
66+
units,
67+
activation=None)
68+
outputs, cell_state = tf.nn.dynamic_rnn(
69+
cell,
70+
x,
71+
dtype=tf.float32)
72+
gru_output_list.append(outputs)
73+
gru_cell_state_list.append(cell_state)
74+
75+
# given scope
76+
cell = rnn.GRUCell(
77+
units,
78+
activation=None)
79+
with variable_scope.variable_scope("root1") as scope:
6980
outputs, cell_state = tf.nn.dynamic_rnn(
7081
cell,
7182
x,
72-
dtype=tf.float32)
73-
gru_output_list.append(outputs)
74-
gru_cell_state_list.append(cell_state)
75-
76-
if True:
77-
# given scope
78-
cell = rnn.GRUCell(
79-
units,
80-
activation=None)
81-
with variable_scope.variable_scope("root1") as scope:
82-
outputs, cell_state = tf.nn.dynamic_rnn(
83-
cell,
84-
x,
85-
dtype=tf.float32,
86-
sequence_length=[4],
87-
scope=scope)
88-
gru_output_list.append(outputs)
89-
gru_cell_state_list.append(cell_state)
83+
dtype=tf.float32,
84+
sequence_length=[4],
85+
scope=scope)
86+
gru_output_list.append(outputs)
87+
gru_cell_state_list.append(cell_state)
9088

9189
_ = tf.identity(gru_output_list, name="output")
9290
_ = tf.identity(gru_cell_state_list, name="cell_state")
@@ -325,19 +323,18 @@ def test_dynamic_bigru(self):
325323
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
326324
initializer = init_ops.constant_initializer(0.5)
327325

328-
if True:
329-
# bigru, no scope
330-
cell1 = rnn.GRUCell(
331-
units,
332-
kernel_initializer=initializer)
333-
cell2 = rnn.GRUCell(
334-
units,
335-
kernel_initializer=initializer)
336-
outputs, cell_state = tf.nn.bidirectional_dynamic_rnn(
337-
cell1,
338-
cell2,
339-
x,
340-
dtype=tf.float32)
326+
# bigru, no scope
327+
cell1 = rnn.GRUCell(
328+
units,
329+
kernel_initializer=initializer)
330+
cell2 = rnn.GRUCell(
331+
units,
332+
kernel_initializer=initializer)
333+
outputs, cell_state = tf.nn.bidirectional_dynamic_rnn(
334+
cell1,
335+
cell2,
336+
x,
337+
dtype=tf.float32)
341338

342339
_ = tf.identity(outputs, name="output")
343340
_ = tf.identity(cell_state, name="cell_state")
@@ -357,19 +354,18 @@ def test_dynamic_bigru_output_consumed_only(self):
357354
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
358355
initializer = init_ops.constant_initializer(0.5)
359356

360-
if True:
361-
# bigru, no scope
362-
cell1 = rnn.GRUCell(
363-
units,
364-
kernel_initializer=initializer)
365-
cell2 = rnn.GRUCell(
366-
units,
367-
kernel_initializer=initializer)
368-
outputs, _ = tf.nn.bidirectional_dynamic_rnn(
369-
cell1,
370-
cell2,
371-
x,
372-
dtype=tf.float32)
357+
# bigru, no scope
358+
cell1 = rnn.GRUCell(
359+
units,
360+
kernel_initializer=initializer)
361+
cell2 = rnn.GRUCell(
362+
units,
363+
kernel_initializer=initializer)
364+
outputs, _ = tf.nn.bidirectional_dynamic_rnn(
365+
cell1,
366+
cell2,
367+
x,
368+
dtype=tf.float32)
373369

374370
_ = tf.identity(outputs, name="output")
375371

@@ -388,21 +384,20 @@ def test_dynamic_bigru_state_consumed_only(self):
388384
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
389385
initializer = init_ops.constant_initializer(0.5)
390386

391-
if True:
392-
# bigru, no scope
393-
cell1 = rnn.GRUCell(
394-
units,
395-
kernel_initializer=initializer)
396-
cell2 = rnn.GRUCell(
397-
units,
398-
kernel_initializer=initializer)
399-
_, cell_state = tf.nn.bidirectional_dynamic_rnn(
400-
cell1,
401-
cell2,
402-
x,
403-
dtype=tf.float32)
387+
# bigru, no scope
388+
cell1 = rnn.GRUCell(
389+
units,
390+
kernel_initializer=initializer)
391+
cell2 = rnn.GRUCell(
392+
units,
393+
kernel_initializer=initializer)
394+
_, cell_state = tf.nn.bidirectional_dynamic_rnn(
395+
cell1,
396+
cell2,
397+
x,
398+
dtype=tf.float32)
404399

405-
_ = tf.identity(cell_state, name="cell_state")
400+
_ = tf.identity(cell_state, name="cell_state")
406401

407402
feed_dict = {"input_1:0": x_val}
408403
input_names_with_port = ["input_1:0"]
@@ -419,16 +414,15 @@ def test_dynamic_bidirectional_but_one_gru(self):
419414
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
420415
initializer = init_ops.constant_initializer(0.5)
421416

422-
if True:
423-
# bigru, no scope
424-
cell = rnn.GRUCell(
425-
units,
426-
kernel_initializer=initializer)
427-
outputs, cell_state = tf.nn.bidirectional_dynamic_rnn(
428-
cell,
429-
cell,
430-
x,
431-
dtype=tf.float32)
417+
# bigru, no scope
418+
cell = rnn.GRUCell(
419+
units,
420+
kernel_initializer=initializer)
421+
outputs, cell_state = tf.nn.bidirectional_dynamic_rnn(
422+
cell,
423+
cell,
424+
x,
425+
dtype=tf.float32)
432426

433427
_ = tf.identity(outputs, name="output")
434428
_ = tf.identity(cell_state, name="cell_state")
@@ -447,15 +441,14 @@ def test_dynamic_bidirectional_but_one_gru_and_output_consumed_only(self):
447441

448442
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
449443

450-
if True:
451-
# bigru, no scope
452-
cell = rnn.GRUCell(
453-
units)
454-
outputs, _ = tf.nn.bidirectional_dynamic_rnn(
455-
cell,
456-
cell,
457-
x,
458-
dtype=tf.float32)
444+
# bigru, no scope
445+
cell = rnn.GRUCell(
446+
units)
447+
outputs, _ = tf.nn.bidirectional_dynamic_rnn(
448+
cell,
449+
cell,
450+
x,
451+
dtype=tf.float32)
459452

460453
_ = tf.identity(outputs, name="output")
461454

@@ -473,15 +466,14 @@ def test_dynamic_bidirectional_but_one_gru_and_state_consumed_only(self):
473466

474467
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
475468

476-
if True:
477-
# bigru, no scope
478-
cell = rnn.GRUCell(
479-
units)
480-
_, cell_state = tf.nn.bidirectional_dynamic_rnn(
481-
cell,
482-
cell,
483-
x,
484-
dtype=tf.float32)
469+
# bigru, no scope
470+
cell = rnn.GRUCell(
471+
units)
472+
_, cell_state = tf.nn.bidirectional_dynamic_rnn(
473+
cell,
474+
cell,
475+
x,
476+
dtype=tf.float32)
485477

486478
_ = tf.identity(cell_state, name="cell_state")
487479

0 commit comments

Comments
 (0)