@@ -205,6 +205,41 @@ def test_pass_initial_state(self):
205
205
output ,
206
206
)
207
207
208
+ def test_pass_return_state (self ):
209
+ sequence = np .arange (24 ).reshape ((2 , 4 , 3 )).astype ("float32" )
210
+ initial_state = np .arange (4 ).reshape ((2 , 2 )).astype ("float32" )
211
+
212
+ # Test with go_backwards=False
213
+ layer = layers .GRU (
214
+ 2 ,
215
+ kernel_initializer = initializers .Constant (0.01 ),
216
+ recurrent_initializer = initializers .Constant (0.02 ),
217
+ bias_initializer = initializers .Constant (0.03 ),
218
+ return_state = True ,
219
+ )
220
+ output , state = layer (sequence , initial_state = initial_state )
221
+ self .assertAllClose (
222
+ np .array ([[0.23774096 , 0.33508456 ], [0.83659905 , 1.0227708 ]]),
223
+ output ,
224
+ )
225
+ self .assertAllClose (output , state )
226
+
227
+ # Test with go_backwards=True
228
+ layer = layers .GRU (
229
+ 2 ,
230
+ kernel_initializer = initializers .Constant (0.01 ),
231
+ recurrent_initializer = initializers .Constant (0.02 ),
232
+ bias_initializer = initializers .Constant (0.03 ),
233
+ return_state = True ,
234
+ go_backwards = True ,
235
+ )
236
+ output , state = layer (sequence , initial_state = initial_state )
237
+ self .assertAllClose (
238
+ np .array ([[0.13486053 , 0.23261218 ], [0.78257304 , 0.9691353 ]]),
239
+ output ,
240
+ )
241
+ self .assertAllClose (output , state )
242
+
208
243
def test_masking (self ):
209
244
sequence = np .arange (24 ).reshape ((2 , 4 , 3 )).astype ("float32" )
210
245
mask = np .array ([[True , True , False , True ], [True , False , False , True ]])
0 commit comments