@@ -113,6 +113,7 @@ def __init__(
113
113
)
114
114
implementation = kwargs .pop ("implementation" , 2 )
115
115
super ().__init__ (** kwargs )
116
+ self .implementation = implementation
116
117
self .units = units
117
118
self .activation = activations .get (activation )
118
119
self .recurrent_activation = activations .get (recurrent_activation )
@@ -132,13 +133,16 @@ def __init__(
132
133
133
134
self .dropout = min (1.0 , max (0.0 , dropout ))
134
135
self .recurrent_dropout = min (1.0 , max (0.0 , recurrent_dropout ))
136
+ if self .recurrent_dropout != 0.0 :
137
+ self .implementation = 1
138
+ if self .implementation == 1 :
139
+ self .dropout_mask_count = 4
135
140
self .seed = seed
136
141
self .seed_generator = backend .random .SeedGenerator (seed = seed )
137
142
138
143
self .unit_forget_bias = unit_forget_bias
139
144
self .state_size = [self .units , self .units ]
140
145
self .output_size = self .units
141
- self .implementation = implementation
142
146
143
147
def build (self , input_shape ):
144
148
super ().build (input_shape )
@@ -228,19 +232,18 @@ def call(self, inputs, states, training=False):
228
232
h_tm1 = states [0 ] # previous memory state
229
233
c_tm1 = states [1 ] # previous carry state
230
234
231
- dp_mask = self .get_dropout_mask (inputs )
232
- rec_dp_mask = self .get_recurrent_dropout_mask (h_tm1 )
233
-
234
- if training and 0.0 < self .dropout < 1.0 :
235
- inputs = inputs * dp_mask
236
- if training and 0.0 < self .recurrent_dropout < 1.0 :
237
- h_tm1 = h_tm1 * rec_dp_mask
238
-
239
235
if self .implementation == 1 :
240
- inputs_i = inputs
241
- inputs_f = inputs
242
- inputs_c = inputs
243
- inputs_o = inputs
236
+ if training and 0.0 < self .dropout < 1.0 :
237
+ dp_mask = self .get_dropout_mask (inputs )
238
+ inputs_i = inputs * dp_mask [0 ]
239
+ inputs_f = inputs * dp_mask [1 ]
240
+ inputs_c = inputs * dp_mask [2 ]
241
+ inputs_o = inputs * dp_mask [3 ]
242
+ else :
243
+ inputs_i = inputs
244
+ inputs_f = inputs
245
+ inputs_c = inputs
246
+ inputs_o = inputs
244
247
k_i , k_f , k_c , k_o = ops .split (self .kernel , 4 , axis = 1 )
245
248
x_i = ops .matmul (inputs_i , k_i )
246
249
x_f = ops .matmul (inputs_f , k_f )
@@ -253,14 +256,25 @@ def call(self, inputs, states, training=False):
253
256
x_c += b_c
254
257
x_o += b_o
255
258
256
- h_tm1_i = h_tm1
257
- h_tm1_f = h_tm1
258
- h_tm1_c = h_tm1
259
- h_tm1_o = h_tm1
259
+ if training and 0.0 < self .recurrent_dropout < 1.0 :
260
+ rec_dp_mask = self .get_recurrent_dropout_mask (h_tm1 )
261
+ h_tm1_i = h_tm1 * rec_dp_mask [0 ]
262
+ h_tm1_f = h_tm1 * rec_dp_mask [1 ]
263
+ h_tm1_c = h_tm1 * rec_dp_mask [2 ]
264
+ h_tm1_o = h_tm1 * rec_dp_mask [3 ]
265
+ else :
266
+ h_tm1_i = h_tm1
267
+ h_tm1_f = h_tm1
268
+ h_tm1_c = h_tm1
269
+ h_tm1_o = h_tm1
260
270
x = (x_i , x_f , x_c , x_o )
261
271
h_tm1 = (h_tm1_i , h_tm1_f , h_tm1_c , h_tm1_o )
262
272
c , o = self ._compute_carry_and_output (x , h_tm1 , c_tm1 )
263
273
else :
274
+ if training and 0.0 < self .dropout < 1.0 :
275
+ dp_mask = self .get_dropout_mask (inputs )
276
+ inputs = inputs * dp_mask
277
+
264
278
z = ops .matmul (inputs , self .kernel )
265
279
266
280
z += ops .matmul (h_tm1 , self .recurrent_kernel )
0 commit comments