@@ -198,6 +198,53 @@ def test_attention_wrapper_lstm_encoder(self):
         output_names_with_port = ["output_0:0", "output:0", "final_state:0"]
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.1)

+    @check_opset_min_version(8, "Scan")
+    @check_tf_min_version("1.8")
+    def test_attention_wrapper_gru_encoder(self):
+        size = 5
+        time_step = 3
+        input_size = 4
+        attn_size = size
+
+        batch_size = 9
+
+        # shape [batch size, time step, size]
+        # attention_state: usually the output of an RNN encoder.
+        # This tensor should be shaped `[batch_size, max_time, ...]`
+        encoder_time_step = time_step
+        encoder_x_val = np.random.randn(encoder_time_step, input_size).astype('f')
+        encoder_x_val = np.stack([encoder_x_val] * batch_size)
+        encoder_x = tf.placeholder(tf.float32, encoder_x_val.shape, name="input_1")
+        encoder_cell = tf.nn.rnn_cell.GRUCell(size)
+        output, attr_state = tf.nn.dynamic_rnn(encoder_cell, encoder_x, dtype=tf.float32)
+        _ = tf.identity(output, name="output_0")
+        attention_states = output
+        attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(attn_size,
+                                                                   attention_states)
+
+        match_input_fn = lambda curr_input, state: tf.concat([curr_input, state], axis=-1)
+        cell = tf.nn.rnn_cell.GRUCell(size)
+        match_cell_fw = tf.contrib.seq2seq.AttentionWrapper(cell,
+                                                            attention_mechanism,
+                                                            attention_layer_size=attn_size,
+                                                            cell_input_fn=match_input_fn,
+                                                            output_attention=False)
+
+        decoder_time_step = 6
+        decoder_x_val = np.random.randn(decoder_time_step, input_size).astype('f')
+        decoder_x_val = np.stack([decoder_x_val] * batch_size)
+
+        decoder_x = tf.placeholder(tf.float32, decoder_x_val.shape, name="input_2")
+        output, attr_state = tf.nn.dynamic_rnn(match_cell_fw, decoder_x, dtype=tf.float32)
+
+        _ = tf.identity(output, name="output")
+        _ = tf.identity(attr_state.cell_state, name="final_state")
+
+        feed_dict = {"input_1:0": encoder_x_val, "input_2:0": decoder_x_val}
+        input_names_with_port = ["input_1:0", "input_2:0"]
+        output_names_with_port = ["output_0:0", "output:0", "final_state:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.1)
+
     @check_opset_min_version(8, "Scan")
     @check_tf_min_version("1.8")
     def test_attention_wrapper_lstm_encoder_input_has_none_dim(self):