Skip to content

Commit 67bcd88

Browse files
authored
Fix GRU with return_state=True on tf backend with cuda (#21603)
1 parent ac5c97f commit 67bcd88

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

keras/src/backend/tensorflow/rnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,7 @@ def _cudnn_gru(
778778
return (
779779
last_output,
780780
outputs,
781-
state,
781+
[state],
782782
)
783783

784784

keras/src/layers/rnn/gru_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,41 @@ def test_pass_initial_state(self):
205205
output,
206206
)
207207

208+
def test_pass_return_state(self):
209+
sequence = np.arange(24).reshape((2, 4, 3)).astype("float32")
210+
initial_state = np.arange(4).reshape((2, 2)).astype("float32")
211+
212+
# Test with go_backwards=False
213+
layer = layers.GRU(
214+
2,
215+
kernel_initializer=initializers.Constant(0.01),
216+
recurrent_initializer=initializers.Constant(0.02),
217+
bias_initializer=initializers.Constant(0.03),
218+
return_state=True,
219+
)
220+
output, state = layer(sequence, initial_state=initial_state)
221+
self.assertAllClose(
222+
np.array([[0.23774096, 0.33508456], [0.83659905, 1.0227708]]),
223+
output,
224+
)
225+
self.assertAllClose(output, state)
226+
227+
# Test with go_backwards=True
228+
layer = layers.GRU(
229+
2,
230+
kernel_initializer=initializers.Constant(0.01),
231+
recurrent_initializer=initializers.Constant(0.02),
232+
bias_initializer=initializers.Constant(0.03),
233+
return_state=True,
234+
go_backwards=True,
235+
)
236+
output, state = layer(sequence, initial_state=initial_state)
237+
self.assertAllClose(
238+
np.array([[0.13486053, 0.23261218], [0.78257304, 0.9691353]]),
239+
output,
240+
)
241+
self.assertAllClose(output, state)
242+
208243
def test_masking(self):
209244
sequence = np.arange(24).reshape((2, 4, 3)).astype("float32")
210245
mask = np.array([[True, True, False, True], [True, False, False, True]])

0 commit comments

Comments
 (0)