Commit 72ca0a0

Fix recurrent dropout for GRU. (#20656)
The simplified implementation, which applied the same recurrent dropout mask to every use of the previous state (rather than one mask per gate), did not work and caused training not to converge at sufficiently large recurrent dropout values. This new implementation now matches Keras 2. Note that recurrent dropout requires "implementation 1" to be turned on. Fixes #20276
1 parent 4c7c4b5 commit 72ca0a0
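
For context, a minimal usage sketch of the configuration this fix targets: a GRU layer with non-zero `recurrent_dropout`, which per the note above runs with implementation 1. Shapes, rates, and the toy data below are illustrative, not taken from this commit.

```python
# Illustrative only: a GRU configured with recurrent dropout, the case this
# commit fixes. Shapes, rates and the toy data are made up for the example.
import numpy as np
import keras

inputs = keras.Input(shape=(20, 8))  # (timesteps, features)
x = keras.layers.GRU(32, dropout=0.2, recurrent_dropout=0.3)(inputs)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.compile(optimizer="adam", loss="mse")
x_train = np.random.random((16, 20, 8)).astype("float32")
y_train = np.random.random((16, 1)).astype("float32")
model.fit(x_train, y_train, epochs=1, verbose=0)  # dropout is only active during training
```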

File tree

1 file changed (+25 -14 lines)

keras/src/layers/rnn/gru.py

Lines changed: 25 additions & 14 deletions
@@ -131,6 +131,9 @@ def __init__(
 
         self.dropout = min(1.0, max(0.0, dropout))
         self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))
+        if self.recurrent_dropout != 0.0:
+            self.implementation = 1
+            self.dropout_mask_count = 3
         self.seed = seed
         self.seed_generator = backend.random.SeedGenerator(seed=seed)
 
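The `dropout_mask_count = 3` above asks the dropout machinery for one mask per gate instead of a single shared mask. As a rough standalone illustration of what such a set of masks looks like (plain NumPy, not the Keras internals; the function name is made up for the sketch):

```python
# A standalone sketch of "three dropout masks": one independent
# inverted-dropout mask per GRU gate (update z, reset r, candidate h).
import numpy as np

def gate_masks(shape, rate, count=3, seed=0):
    rng = np.random.default_rng(seed)
    keep = 1.0 - rate
    # Each mask zeroes units independently and rescales by 1/keep so the
    # expected value of the masked activation is unchanged.
    return [rng.binomial(1, keep, size=shape).astype("float32") / keep
            for _ in range(count)]

z_mask, r_mask, h_mask = gate_masks(shape=(4, 32), rate=0.3)
```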
@@ -181,9 +184,6 @@ def call(self, inputs, states, training=False):
             states[0] if tree.is_nested(states) else states
         )  # previous state
 
-        dp_mask = self.get_dropout_mask(inputs)
-        rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1)
-
         if self.use_bias:
             if not self.reset_after:
                 input_bias, recurrent_bias = self.bias, None
@@ -193,15 +193,16 @@ def call(self, inputs, states, training=False):
                     for e in ops.split(self.bias, self.bias.shape[0], axis=0)
                 )
 
-        if training and 0.0 < self.dropout < 1.0:
-            inputs = inputs * dp_mask
-        if training and 0.0 < self.recurrent_dropout < 1.0:
-            h_tm1 = h_tm1 * rec_dp_mask
-
         if self.implementation == 1:
-            inputs_z = inputs
-            inputs_r = inputs
-            inputs_h = inputs
+            if training and 0.0 < self.dropout < 1.0:
+                dp_mask = self.get_dropout_mask(inputs)
+                inputs_z = inputs * dp_mask[0]
+                inputs_r = inputs * dp_mask[1]
+                inputs_h = inputs * dp_mask[2]
+            else:
+                inputs_z = inputs
+                inputs_r = inputs
+                inputs_h = inputs
 
             x_z = ops.matmul(inputs_z, self.kernel[:, : self.units])
             x_r = ops.matmul(
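
A tiny numeric sketch (again NumPy, not Keras code) of the behavioural difference the per-gate input masks make: with one shared mask, all three gate projections see the same thinned input, while independent masks thin each gate's input differently, which is the Keras 2 behaviour being restored.

```python
import numpy as np

rng = np.random.default_rng(0)
inputs = np.ones((1, 6), dtype="float32")
keep = 0.5

shared = rng.binomial(1, keep, size=inputs.shape) / keep
per_gate = [rng.binomial(1, keep, size=inputs.shape) / keep for _ in range(3)]

# Old behaviour: z, r and h all receive inputs * shared (identical features dropped).
# New behaviour: inputs_z, inputs_r, inputs_h are each thinned independently.
inputs_z, inputs_r, inputs_h = (inputs * m for m in per_gate)
```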
@@ -214,9 +215,15 @@ def call(self, inputs, states, training=False):
                 x_r += input_bias[self.units : self.units * 2]
                 x_h += input_bias[self.units * 2 :]
 
-            h_tm1_z = h_tm1
-            h_tm1_r = h_tm1
-            h_tm1_h = h_tm1
+            if training and 0.0 < self.recurrent_dropout < 1.0:
+                rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1)
+                h_tm1_z = h_tm1 * rec_dp_mask[0]
+                h_tm1_r = h_tm1 * rec_dp_mask[1]
+                h_tm1_h = h_tm1 * rec_dp_mask[2]
+            else:
+                h_tm1_z = h_tm1
+                h_tm1_r = h_tm1
+                h_tm1_h = h_tm1
 
             recurrent_z = ops.matmul(
                 h_tm1_z, self.recurrent_kernel[:, : self.units]
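
For orientation, the recurrence these projections feed into is the standard GRU update, sketched below in plain NumPy assuming the default sigmoid/tanh activations and reset_after=True, with biases omitted; `h_z`, `h_r`, `h_h` stand for the recurrent projections of the three independently masked copies of the previous state.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(x_z, x_r, x_h, h_z, h_r, h_h, h_tm1):
    z = sigmoid(x_z + h_z)             # update gate
    r = sigmoid(x_r + h_r)             # reset gate
    hh = np.tanh(x_h + r * h_h)        # candidate state
    return z * h_tm1 + (1.0 - z) * hh  # new hidden state
```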
@@ -246,6 +253,10 @@ def call(self, inputs, states, training=False):
 
             hh = self.activation(x_h + recurrent_h)
         else:
+            if training and 0.0 < self.dropout < 1.0:
+                dp_mask = self.get_dropout_mask(inputs)
+                inputs = inputs * dp_mask[0]
+
             # inputs projected by all gate matrices at once
             matrix_x = ops.matmul(inputs, self.kernel)
             if self.use_bias:
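
In this fused implementation-2 path the whole kernel is applied in one matmul, so only a single input mask (`dp_mask[0]`) can be applied and per-gate recurrent masks cannot be used, which is why the constructor change forces implementation 1 whenever recurrent dropout is non-zero. A small NumPy sketch of that fused projection (shapes are made up):

```python
import numpy as np

units, features = 4, 3
kernel = np.random.random((features, 3 * units)).astype("float32")
inputs = np.random.random((2, features)).astype("float32")

mask = (np.random.random(inputs.shape) > 0.5).astype("float32") / 0.5  # one shared mask
matrix_x = (inputs * mask) @ kernel            # z, r and h projected in one go
x_z, x_r, x_h = np.split(matrix_x, 3, axis=-1)
```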
