@@ -131,6 +131,9 @@ def __init__(
         self.dropout = min(1.0, max(0.0, dropout))
         self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))
+        if self.recurrent_dropout != 0.0:
+            self.implementation = 1
+        self.dropout_mask_count = 3
         self.seed = seed
         self.seed_generator = backend.random.SeedGenerator(seed=seed)
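Note on the `__init__` change: implementation 2 fuses all three gate projections into a single matmul against `self.kernel`, so it cannot apply a distinct dropout mask per gate; forcing `self.implementation = 1` whenever recurrent dropout is active keeps the per-gate masks applicable. `self.dropout_mask_count = 3` is presumably what tells the `DropoutRNNCell` mask helpers to return one mask per gate (update z, reset r, candidate h). A minimal NumPy sketch of what such a stacked, inverted-dropout mask could look like (illustrative only, not the Keras internals):

```python
# Illustrative sketch of a stacked, inverted dropout mask with one
# slice per GRU gate (z, r, h). Names here are hypothetical.
import numpy as np

def make_gate_masks(step_input, rate, count=3, seed=0):
    rng = np.random.default_rng(seed)
    # Independent keep/drop pattern per gate, scaled by 1/(1 - rate)
    # so the expected activation magnitude is unchanged during training.
    keep = rng.random((count,) + step_input.shape) >= rate
    return keep.astype(step_input.dtype) / (1.0 - rate)

x = np.ones((2, 4), dtype="float32")  # (batch, features)
masks = make_gate_masks(x, rate=0.5)
print(masks.shape)  # (3, 2, 4): masks[0] -> z, masks[1] -> r, masks[2] -> h
```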
@@ -181,9 +184,6 @@ def call(self, inputs, states, training=False):
             states[0] if tree.is_nested(states) else states
         )  # previous state
 
-        dp_mask = self.get_dropout_mask(inputs)
-        rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1)
-
         if self.use_bias:
             if not self.reset_after:
                 input_bias, recurrent_bias = self.bias, None
@@ -193,15 +193,16 @@ def call(self, inputs, states, training=False):
                     for e in ops.split(self.bias, self.bias.shape[0], axis=0)
                 )
 
-        if training and 0.0 < self.dropout < 1.0:
-            inputs = inputs * dp_mask
-        if training and 0.0 < self.recurrent_dropout < 1.0:
-            h_tm1 = h_tm1 * rec_dp_mask
-
         if self.implementation == 1:
-            inputs_z = inputs
-            inputs_r = inputs
-            inputs_h = inputs
+            if training and 0.0 < self.dropout < 1.0:
+                dp_mask = self.get_dropout_mask(inputs)
+                inputs_z = inputs * dp_mask[0]
+                inputs_r = inputs * dp_mask[1]
+                inputs_h = inputs * dp_mask[2]
+            else:
+                inputs_z = inputs
+                inputs_r = inputs
+                inputs_h = inputs
 
             x_z = ops.matmul(inputs_z, self.kernel[:, : self.units])
             x_r = ops.matmul(
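This is the key behavioral change on the implementation-1 path: previously one shared `dp_mask` was applied to `inputs` before all three gate projections, so every gate saw the same drop pattern; now `dp_mask` carries three independently sampled masks and each gate gets its own. Mask creation also moves inside the `training` branch, so no mask is generated when dropout is inactive. A small NumPy contrast sketch (stand-ins, not the actual code path):

```python
# Illustrative contrast between the old shared mask and the new
# per-gate masks, using NumPy stand-ins for the cell's inputs.
import numpy as np

rng = np.random.default_rng(42)
inputs = np.ones((1, 4), dtype="float32")
rate = 0.5

# Old: one mask reused by every gate -> identical zero pattern in z, r, h.
shared = (rng.random(inputs.shape) >= rate) / (1.0 - rate)
inputs_z_old = inputs_r_old = inputs_h_old = inputs * shared

# New: dp_mask holds one independently drawn mask per gate, so the
# update (z), reset (r) and candidate (h) paths drop different units.
dp_mask = (rng.random((3,) + inputs.shape) >= rate) / (1.0 - rate)
inputs_z, inputs_r, inputs_h = (inputs * m for m in dp_mask)
```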
@@ -214,9 +215,15 @@ def call(self, inputs, states, training=False):
                 x_r += input_bias[self.units : self.units * 2]
                 x_h += input_bias[self.units * 2 :]
 
-            h_tm1_z = h_tm1
-            h_tm1_r = h_tm1
-            h_tm1_h = h_tm1
+            if training and 0.0 < self.recurrent_dropout < 1.0:
+                rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1)
+                h_tm1_z = h_tm1 * rec_dp_mask[0]
+                h_tm1_r = h_tm1 * rec_dp_mask[1]
+                h_tm1_h = h_tm1 * rec_dp_mask[2]
+            else:
+                h_tm1_z = h_tm1
+                h_tm1_r = h_tm1
+                h_tm1_h = h_tm1
 
             recurrent_z = ops.matmul(
                 h_tm1_z, self.recurrent_kernel[:, : self.units]
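The recurrent state gets the same per-gate treatment via `rec_dp_mask`. One property worth noting (assuming the `DropoutRNNCell` helpers behave as they do elsewhere in Keras): the mask is created once and cached for the whole sequence, so every timestep reuses the identical drop pattern, in the spirit of variational RNN dropout (Gal & Ghahramani, 2016). A sketch of that reuse:

```python
# Sketch of the property the recurrent masks rely on: the same per-gate
# masks are reused at every timestep (drawn once per sequence).
import numpy as np

rng = np.random.default_rng(7)
units, rate = 4, 0.5
rec_dp_mask = (rng.random((3, 1, units)) >= rate) / (1.0 - rate)

h_tm1 = np.ones((1, units), dtype="float32")
for _ in range(3):  # every step sees the identical masks
    h_z, h_r, h_h = (h_tm1 * m for m in rec_dp_mask)
    # ... gate math would update h_tm1 here ...
```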
@@ -246,6 +253,10 @@ def call(self, inputs, states, training=False):
 
             hh = self.activation(x_h + recurrent_h)
         else:
+            if training and 0.0 < self.dropout < 1.0:
+                dp_mask = self.get_dropout_mask(inputs)
+                inputs = inputs * dp_mask[0]
+
             # inputs projected by all gate matrices at once
             matrix_x = ops.matmul(inputs, self.kernel)
             if self.use_bias:
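In the fused implementation-2 branch only a single input mask (`dp_mask[0]`) is applied, since one matmul projects all gates at once; this is why recurrent dropout must fall back to implementation 1. For completeness, a usage sketch that exercises both patched paths (assumes Keras 3; the arguments shown are the public GRU API):

```python
# Usage sketch (Keras 3): recurrent_dropout forces the cell onto the
# implementation-1 path patched above; dropout alone can stay on the
# fused implementation-2 path, which applies a single input mask.
import numpy as np
import keras

x = np.random.rand(8, 10, 16).astype("float32")  # (batch, time, features)
layer = keras.layers.GRU(32, dropout=0.2, recurrent_dropout=0.2, seed=42)
y = layer(x, training=True)  # masks are only applied when training=True
print(y.shape)  # (8, 32)
```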