Commit a62e0ef

Change recurrent dropout implementation for LSTM. (#20663)
This change makes the recurrent dropout implementation for LSTM consistent with GRU (changed as of #20656) and with Keras 2. It also fixes a bug where the GRU fix would break when using cuDNN with dropout and no recurrent dropout. The solution is to create multiple dropout masks only when they are needed (implementation == 1). Test coverage was added for the case where dropout is set and recurrent dropout is not.
1 parent eb8d13c commit a62e0ef
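
For context, here is a minimal usage sketch (not part of the commit) of the two configurations the change distinguishes; the shapes mirror the new tests below, and the comments follow the constructor logic in the diffs:

```python
import numpy as np
from keras import layers

x = np.random.random((3, 2, 4)).astype("float32")

# dropout only: a single input mask is used and the default fused
# implementation (2) is kept, so the cuDNN path remains eligible.
# This is the configuration the new tests cover.
y = layers.GRU(3, dropout=0.5)(x, training=True)

# recurrent_dropout forces implementation 1, where one dropout mask
# per gate is created (3 masks for GRU, 4 for LSTM).
z = layers.LSTM(3, dropout=0.5, recurrent_dropout=0.5)(x, training=True)
```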

4 files changed: 54 additions, 19 deletions

keras/src/layers/rnn/gru.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -133,7 +133,8 @@ def __init__(
         self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))
         if self.recurrent_dropout != 0.0:
             self.implementation = 1
-            self.dropout_mask_count = 3
+        if self.implementation == 1:
+            self.dropout_mask_count = 3
         self.seed = seed
         self.seed_generator = backend.random.SeedGenerator(seed=seed)

@@ -255,7 +256,7 @@ def call(self, inputs, states, training=False):
         else:
             if training and 0.0 < self.dropout < 1.0:
                 dp_mask = self.get_dropout_mask(inputs)
-                inputs = inputs * dp_mask[0]
+                inputs = inputs * dp_mask

             # inputs projected by all gate matrices at once
             matrix_x = ops.matmul(inputs, self.kernel)
```
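
The `dp_mask[0]` indexing in the fused path assumed a stack of per-gate masks, but when dropout is set and recurrent dropout is not, `dropout_mask_count` is never set and `get_dropout_mask` returns a single mask, so indexing it would slice into the mask tensor itself. A standalone sketch (not the Keras source) of the single-mask versus per-gate-mask distinction:

```python
import numpy as np

def dropout_masks(rng, x, rate, count=None):
    # Inverted dropout: zero units with probability `rate` and scale
    # the survivors so the expected activation stays unchanged.
    def one_mask():
        keep = (rng.random(x.shape) >= rate).astype(x.dtype)
        return keep / (1.0 - rate)

    if count is None:
        return one_mask()  # fused path: one mask for the whole input
    return [one_mask() for _ in range(count)]  # one mask per gate

rng = np.random.default_rng(0)
step_input = np.ones((3, 4), dtype="float32")

single = dropout_masks(rng, step_input, 0.5)             # dropout only
per_gate = dropout_masks(rng, step_input, 0.5, count=3)  # implementation 1
assert single.shape == step_input.shape
assert len(per_gate) == 3
```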

keras/src/layers/rnn/gru_test.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -10,6 +10,16 @@
 class GRUTest(testing.TestCase):
     @pytest.mark.requires_trainable_backend
     def test_basics(self):
+        self.run_layer_test(
+            layers.GRU,
+            init_kwargs={"units": 3, "dropout": 0.5},
+            input_shape=(3, 2, 4),
+            call_kwargs={"training": True},
+            expected_output_shape=(3, 3),
+            expected_num_trainable_weights=3,
+            expected_num_non_trainable_weights=0,
+            supports_masking=True,
+        )
         self.run_layer_test(
             layers.GRU,
             init_kwargs={"units": 3, "dropout": 0.5, "recurrent_dropout": 0.5},
```

keras/src/layers/rnn/lstm.py

Lines changed: 31 additions & 17 deletions
```diff
@@ -113,6 +113,7 @@ def __init__(
         )
         implementation = kwargs.pop("implementation", 2)
         super().__init__(**kwargs)
+        self.implementation = implementation
         self.units = units
         self.activation = activations.get(activation)
         self.recurrent_activation = activations.get(recurrent_activation)
@@ -132,13 +133,16 @@ def __init__(

         self.dropout = min(1.0, max(0.0, dropout))
         self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))
+        if self.recurrent_dropout != 0.0:
+            self.implementation = 1
+        if self.implementation == 1:
+            self.dropout_mask_count = 4
         self.seed = seed
         self.seed_generator = backend.random.SeedGenerator(seed=seed)

         self.unit_forget_bias = unit_forget_bias
         self.state_size = [self.units, self.units]
         self.output_size = self.units
-        self.implementation = implementation

     def build(self, input_shape):
         super().build(input_shape)
@@ -228,19 +232,18 @@ def call(self, inputs, states, training=False):
         h_tm1 = states[0]  # previous memory state
         c_tm1 = states[1]  # previous carry state

-        dp_mask = self.get_dropout_mask(inputs)
-        rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1)
-
-        if training and 0.0 < self.dropout < 1.0:
-            inputs = inputs * dp_mask
-        if training and 0.0 < self.recurrent_dropout < 1.0:
-            h_tm1 = h_tm1 * rec_dp_mask
-
         if self.implementation == 1:
-            inputs_i = inputs
-            inputs_f = inputs
-            inputs_c = inputs
-            inputs_o = inputs
+            if training and 0.0 < self.dropout < 1.0:
+                dp_mask = self.get_dropout_mask(inputs)
+                inputs_i = inputs * dp_mask[0]
+                inputs_f = inputs * dp_mask[1]
+                inputs_c = inputs * dp_mask[2]
+                inputs_o = inputs * dp_mask[3]
+            else:
+                inputs_i = inputs
+                inputs_f = inputs
+                inputs_c = inputs
+                inputs_o = inputs
             k_i, k_f, k_c, k_o = ops.split(self.kernel, 4, axis=1)
             x_i = ops.matmul(inputs_i, k_i)
             x_f = ops.matmul(inputs_f, k_f)
@@ -253,14 +256,25 @@ def call(self, inputs, states, training=False):
             x_c += b_c
             x_o += b_o

-            h_tm1_i = h_tm1
-            h_tm1_f = h_tm1
-            h_tm1_c = h_tm1
-            h_tm1_o = h_tm1
+            if training and 0.0 < self.recurrent_dropout < 1.0:
+                rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1)
+                h_tm1_i = h_tm1 * rec_dp_mask[0]
+                h_tm1_f = h_tm1 * rec_dp_mask[1]
+                h_tm1_c = h_tm1 * rec_dp_mask[2]
+                h_tm1_o = h_tm1 * rec_dp_mask[3]
+            else:
+                h_tm1_i = h_tm1
+                h_tm1_f = h_tm1
+                h_tm1_c = h_tm1
+                h_tm1_o = h_tm1
             x = (x_i, x_f, x_c, x_o)
             h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o)
             c, o = self._compute_carry_and_output(x, h_tm1, c_tm1)
         else:
+            if training and 0.0 < self.dropout < 1.0:
+                dp_mask = self.get_dropout_mask(inputs)
+                inputs = inputs * dp_mask
+
             z = ops.matmul(inputs, self.kernel)

             z += ops.matmul(h_tm1, self.recurrent_kernel)
```
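
Based on the constructor logic above, a quick way to observe the implementation switch (a usage sketch, not part of the commit):

```python
from keras import layers

# recurrent_dropout switches the cell to implementation 1, where one
# dropout mask per gate (i, f, c, o) is created inside call().
cell = layers.LSTMCell(4, dropout=0.3, recurrent_dropout=0.3)
print(cell.implementation)  # 1

# With dropout alone, the fused default (implementation 2) is kept:
# a single mask multiplies the inputs before the combined matmul.
cell = layers.LSTMCell(4, dropout=0.3)
print(cell.implementation)  # 2
```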

keras/src/layers/rnn/lstm_test.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -10,6 +10,16 @@
 class LSTMTest(testing.TestCase):
     @pytest.mark.requires_trainable_backend
     def test_basics(self):
+        self.run_layer_test(
+            layers.LSTM,
+            init_kwargs={"units": 3, "dropout": 0.5},
+            input_shape=(3, 2, 4),
+            call_kwargs={"training": True},
+            expected_output_shape=(3, 3),
+            expected_num_trainable_weights=3,
+            expected_num_non_trainable_weights=0,
+            supports_masking=True,
+        )
         self.run_layer_test(
             layers.LSTM,
             init_kwargs={"units": 3, "dropout": 0.5, "recurrent_dropout": 0.5},
```
