Skip to content

Commit d3f2d1e

Browse files
authored
Merge pull request #445 from lucienwang1009/gru_rewriter
Refactor gru_rewriter
2 parents 757ed9a + dbf6d40 commit d3f2d1e

12 files changed

+401
-1491
lines changed

tests/test_custom_rnncell.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,53 @@ def test_attention_wrapper_lstm_encoder(self):
198198
output_names_with_port = ["output_0:0", "output:0", "final_state:0"]
199199
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.1)
200200

201+
@check_opset_min_version(8, "Scan")
202+
@check_tf_min_version("1.8")
203+
def test_attention_wrapper_gru_encoder(self):
204+
size = 5
205+
time_step = 3
206+
input_size = 4
207+
attn_size = size
208+
209+
batch_size = 9
210+
211+
# shape [batch size, time step, size]
212+
# attention_state: usually the output of an RNN encoder.
213+
# This tensor should be shaped `[batch_size, max_time, ...]`
214+
encoder_time_step = time_step
215+
encoder_x_val = np.random.randn(encoder_time_step, input_size).astype('f')
216+
encoder_x_val = np.stack([encoder_x_val] * batch_size)
217+
encoder_x = tf.placeholder(tf.float32, encoder_x_val.shape, name="input_1")
218+
encoder_cell = tf.nn.rnn_cell.GRUCell(size)
219+
output, attr_state = tf.nn.dynamic_rnn(encoder_cell, encoder_x, dtype=tf.float32)
220+
_ = tf.identity(output, name="output_0")
221+
attention_states = output
222+
attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(attn_size,
223+
attention_states)
224+
225+
match_input_fn = lambda curr_input, state: tf.concat([curr_input, state], axis=-1)
226+
cell = tf.nn.rnn_cell.GRUCell(size)
227+
match_cell_fw = tf.contrib.seq2seq.AttentionWrapper(cell,
228+
attention_mechanism,
229+
attention_layer_size=attn_size,
230+
cell_input_fn=match_input_fn,
231+
output_attention=False)
232+
233+
decoder_time_step = 6
234+
decoder_x_val = np.random.randn(decoder_time_step, input_size).astype('f')
235+
decoder_x_val = np.stack([decoder_x_val] * batch_size)
236+
237+
decoder_x = tf.placeholder(tf.float32, decoder_x_val.shape, name="input_2")
238+
output, attr_state = tf.nn.dynamic_rnn(match_cell_fw, decoder_x, dtype=tf.float32)
239+
240+
_ = tf.identity(output, name="output")
241+
_ = tf.identity(attr_state.cell_state, name="final_state")
242+
243+
feed_dict = {"input_1:0": encoder_x_val, "input_2:0": decoder_x_val}
244+
input_names_with_port = ["input_1:0", "input_2:0"]
245+
output_names_with_port = ["output_0:0", "output:0", "final_state:0"]
246+
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.1)
247+
201248
@check_opset_min_version(8, "Scan")
202249
@check_tf_min_version("1.8")
203250
def test_attention_wrapper_lstm_encoder_input_has_none_dim(self):

tf2onnx/rewriter/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from tf2onnx.rewriter.random_uniform import rewrite_random_uniform, rewrite_random_uniform_fold_const
1111
from tf2onnx.rewriter.leakyrelu_rewriter import rewrite_leakyrelu
1212
from tf2onnx.rewriter.rnn import rewrite_single_direction_lstm, rewrite_bi_direction_lstm, \
13-
rewrite_single_direction_gru, rewrite_single_direction_grublock, rewrite_bi_direction_gru, \
13+
rewrite_single_direction_gru, rewrite_bi_direction_gru, \
1414
rewrite_custom_rnn_cell, rewrite_generic_loop
1515

1616
__all__ = [
@@ -21,7 +21,6 @@
2121
"rewrite_single_direction_lstm",
2222
"rewrite_bi_direction_lstm",
2323
"rewrite_single_direction_gru",
24-
"rewrite_single_direction_grublock",
2524
"rewrite_bi_direction_gru",
2625
"rewrite_custom_rnn_cell",
2726
"rewrite_generic_loop"

tf2onnx/rewriter/gru_rewriter.py

Lines changed: 110 additions & 213 deletions
Large diffs are not rendered by default.

tf2onnx/rewriter/grublock_rewriter.py

Lines changed: 0 additions & 125 deletions
This file was deleted.

tf2onnx/rewriter/loop_rewriter_base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,13 @@ def rewrite(self, context):
193193
return REWRITER_RESULT.FAIL
194194

195195
def run_internal(self):
196+
loopcond_ops = []
196197
for op in self.g.get_nodes():
197-
if not is_loopcond_op(op):
198-
continue
198+
if is_loopcond_op(op):
199+
loopcond_ops.append(op)
199200

201+
# self.g.get_nodes() may change while rewriting, so collect all LoopCond ops up front
202+
for op in loopcond_ops:
200203
log.debug("======================\n handling loop cond node called %s", op.name)
201204
context = self.create_context()
202205
context.loop_cond = op

0 commit comments

Comments
 (0)