@@ -318,30 +318,51 @@ class FastGRNNCUDACell(RNNCell):
    h_t = z_t*h_{t-1} + (sigmoid(zeta)(1-z_t) + sigmoid(nu))*h_t^
    '''
-    def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
-        super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_non_linearity, "tanh", 1, 1, 2)
+    def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
+                 update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
+        super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_nonlinearity, update_nonlinearity, 1, 1, 2, wRank, uRank)

        if utils.findCUDA() is None:
-            raise Exception('FastGRNNCUDACell is supported only on GPU devices.')
+            raise Exception('FastGRNNCUDA is supported only on GPU devices.')

        NON_LINEARITY = {"sigmoid": 0, "relu": 1, "tanh": 2}
        self._input_size = input_size
        self._hidden_size = hidden_size
        self._zetaInit = zetaInit
        self._nuInit = nuInit
        self._name = name
-        self._gate_non_linearity = NON_LINEARITY[gate_non_linearity]
-        self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))
-        self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
+
+        if wRank is not None:
+            self._num_W_matrices += 1
+            self._num_weight_matrices[0] = self._num_W_matrices
+        if uRank is not None:
+            self._num_U_matrices += 1
+            self._num_weight_matrices[1] = self._num_U_matrices
+
+        if wRank is None:
+            self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))
+            self.W1 = torch.empty(0)
+            self.W2 = torch.empty(0)
+        else:
+            self.W = torch.empty(0)
+            self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size]))
+            self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank]))
+
+        if uRank is None:
+            self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
+            self.U1 = torch.empty(0)
+            self.U2 = torch.empty(0)
+        else:
+            self.U = torch.empty(0)
+            self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size]))
+            self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank]))
+
+        self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]

        self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
        self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
        self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
        self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))
365
340
- def reset_parameters (self ):
341
- stdv = 1.0 / math .sqrt (self .state_size )
342
- for weight in self .parameters ():
343
- weight .data .uniform_ (- stdv , + stdv )
344
-
345
366
@property
346
367
def name (self ):
347
368
return self ._name
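
Annotation (not part of the diff): the new wRank/uRank branches swap each full matrix for a two-factor product, W ≈ W2·W1 and U ≈ U2·U1, the low-rank compression FastGRNN uses elsewhere in this file. A minimal sketch of the parameter saving, with hypothetical sizes:

```python
import torch

# Hypothetical sizes, for illustration only.
input_size, hidden_size, wRank = 32, 100, 16

W_full = torch.randn(hidden_size, input_size)  # 100*32 = 3200 parameters
W1 = torch.randn(wRank, input_size)            # 16*32  =  512 parameters
W2 = torch.randn(hidden_size, wRank)           # 100*16 = 1600 parameters

x = torch.randn(input_size)
y_full = W_full @ x         # the uncompressed product W x
y_lowrank = W2 @ (W1 @ x)   # rank-16 stand-in: 2112 parameters vs 3200
assert y_full.shape == y_lowrank.shape
```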
@@ -352,10 +373,23 @@ def cellType(self):

    def forward(self, input, state):
        # Calls the custom autograd function, which invokes the CUDA implementation.
-        return FastGRNNFunction.apply(input, self.W, self.U, self.bias_gate, self.bias_update, self.zeta, self.nu, state, self._gate_non_linearity)
+        return FastGRNNFunction.apply(input, self.bias_gate, self.bias_update, self.zeta, self.nu, state,
+                                      self.W, self.U, self.W1, self.W2, self.U1, self.U2, self._gate_non_linearity)

    def getVars(self):
-        return [self.W, self.U, self.bias_gate, self.bias_update, self.zeta, self.nu]
+        Vars = []
+        if self._num_W_matrices == 1:
+            Vars.append(self.W)
+        else:
+            Vars.extend([self.W1, self.W2])
+
+        if self._num_U_matrices == 1:
+            Vars.append(self.U)
+        else:
+            Vars.extend([self.U1, self.U2])
+
+        Vars.extend([self.bias_gate, self.bias_update, self.zeta, self.nu])
+        return Vars

class FastRNNCell(RNNCell):
    '''
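
Annotation (not part of the diff): the forward above defers to FastGRNNFunction, which calls into the fastgrnn_cuda extension. For reference, the recurrence the kernel evaluates is the FastGRNN update from the docstring at the top of this hunk; a plain-PyTorch sketch of one step, assuming a sigmoid gate and tanh update nonlinearity:

```python
import torch

def fastgrnn_step_reference(x, h, W, U, bias_gate, bias_update, zeta, nu):
    # Shared pre-activation: W x_t + U h_{t-1}
    pre = x @ W.t() + h @ U.t()
    # z_t = sigmoid(W x_t + U h_{t-1} + b_g)
    z = torch.sigmoid(pre + bias_gate)
    # h_t^ = tanh(W x_t + U h_{t-1} + b_h)
    h_tilde = torch.tanh(pre + bias_update)
    # h_t = z_t*h_{t-1} + (sigmoid(zeta)(1-z_t) + sigmoid(nu))*h_t^
    return z * h + (torch.sigmoid(zeta) * (1.0 - z) + torch.sigmoid(nu)) * h_tilde
```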
@@ -1104,8 +1138,6 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
            self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank]))

        self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]
-        self.W = nn.Parameter(0.1 * torch.randn([input_size, hidden_size]))
-        self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))

        self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
        self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
@@ -1118,9 +1150,19 @@ def forward(self, input, h_state, cell_state=None):
                                      self.W, self.U, self.W1, self.W2, self.U1, self.U2, self._gate_non_linearity)

    def getVars(self):
-        if self._num_W_matrices != 1:
-            return [self.W1, self.W2, self.U1, self.U2, self.bias_gate, self.bias_update, self.zeta, self.nu]
-        return [self.W, self.U, self.bias_gate, self.bias_update, self.zeta, self.nu]
+        Vars = []
+        if self._num_W_matrices == 1:
+            Vars.append(self.W)
+        else:
+            Vars.extend([self.W1, self.W2])
+
+        if self._num_U_matrices == 1:
+            Vars.append(self.U)
+        else:
+            Vars.extend([self.U1, self.U2])
+
+        Vars.extend([self.bias_gate, self.bias_update, self.zeta, self.nu])
+        return Vars

class SRNN2(nn.Module):
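
Annotation (not part of the diff): end to end, the patched cell would be constructed and stepped roughly as below (sizes hypothetical; this needs a CUDA device and the built fastgrnn_cuda extension). The call dispatches to the FastGRNNFunction hunk that follows.

```python
import torch

# Hypothetical sizes; FastGRNNCUDACell is the class patched above.
cell = FastGRNNCUDACell(input_size=32, hidden_size=100, wRank=16, uRank=16).cuda()
x = torch.randn(64, 32, device="cuda")   # one timestep for a batch of 64
h = torch.zeros(64, 100, device="cuda")  # initial hidden state
h = cell(x, h)                           # one step via fastgrnn_cuda.forward
```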
@@ -1239,10 +1281,10 @@ def forward(self, x, brickSize):

class FastGRNNFunction(Function):
    @staticmethod
-    def forward(ctx, input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity):
-        outputs = fastgrnn_cuda.forward(input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity)
+    def forward(ctx, input, bias_gate, bias_update, zeta, nu, old_h, w, u, w1, w2, u1, u2, gate_non_linearity):
+        outputs = fastgrnn_cuda.forward(input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity, w1, w2, u1, u2)
        new_h = outputs[0]
-        variables = [input, old_h, zeta, nu, w, u] + outputs[1:]
+        variables = [input, old_h, zeta, nu, w, u] + outputs[1:] + [w1, w2, u1, u2]
        ctx.save_for_backward(*variables)
        ctx.non_linearity = gate_non_linearity
        return new_h
@@ -1251,8 +1293,7 @@ def forward(ctx, input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_

    def backward(ctx, grad_h):
        outputs = fastgrnn_cuda.backward(
            grad_h.contiguous(), *ctx.saved_variables, ctx.non_linearity)
-        d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h = outputs
-        return d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h, None
+        return tuple(outputs + [None])

class FastGRNNUnrollFunction(Function):
    @staticmethod
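
Annotation (not part of the diff): the `return tuple(outputs + [None])` change relies on the autograd contract that backward returns one value per forward input, with None for non-differentiable inputs such as the integer gate_non_linearity flag. A minimal sketch of that rule:

```python
import torch
from torch.autograd import Function

class ScaleByFlag(Function):
    @staticmethod
    def forward(ctx, x, flag):  # one tensor input plus an int flag
        ctx.flag = flag
        return x * flag

    @staticmethod
    def backward(ctx, grad_out):
        # One return value per forward input; the non-tensor flag gets None,
        # just like the trailing None appended above.
        return grad_out * ctx.flag, None

y = ScaleByFlag.apply(torch.ones(3, requires_grad=True), 2)
y.sum().backward()  # succeeds; no gradient is produced for the flag
```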