@@ -329,7 +329,7 @@ def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", zetaIn
         self._nuInit = nuInit
         self._name = name
         self._gate_non_linearity = NON_LINEARITY[gate_non_linearity]
-        self.W = nn.Parameter(0.1 * torch.randn([input_size, hidden_size]))
+        self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))
         self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
 
         self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
@@ -1065,7 +1065,8 @@ def forward(self, input, hiddenState=None, cellState=None):
 
 class FastGRNNCUDA(nn.Module):
     """Unrolled implementation of the FastGRNNCUDACell"""
-    def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
+    def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
+                 update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
         super(FastGRNNCUDA, self).__init__()
         if utils.findCUDA() is None:
             raise Exception('FastGRNNCUDA is supported only on GPU devices.')
@@ -1075,7 +1076,34 @@ def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", zetaIn
         self._zetaInit = zetaInit
         self._nuInit = nuInit
         self._name = name
-        self._gate_non_linearity = NON_LINEARITY[gate_non_linearity]
+
+        if wRank is not None:
+            self._num_W_matrices += 1
+            self._num_weight_matrices[0] = self._num_W_matrices
+        if uRank is not None:
+            self._num_U_matrices += 1
+            self._num_weight_matrices[1] = self._num_U_matrices
+        self._name = name
+
+        if wRank is None:
+            self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))
+            self.W1 = torch.empty(0)
+            self.W2 = torch.empty(0)
+        else:
+            self.W = torch.empty(0)
+            self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size]))
+            self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank]))
+
+        if uRank is None:
+            self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
+            self.U1 = torch.empty(0)
+            self.U2 = torch.empty(0)
+        else:
+            self.U = torch.empty(0)
+            self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size]))
+            self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank]))
+
+        self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]
         self.W = nn.Parameter(0.1 * torch.randn([input_size, hidden_size]))
         self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
 
@@ -1086,9 +1114,12 @@ def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", zetaIn
 
     def forward(self, input, h_state, cell_state=None):
         # input: [timesteps, batch, features, state_size]
-        return FastGRNNUnrollFunction.apply(input, self.W, self.U, self.bias_gate, self.bias_update, self.zeta, self.nu, h_state, self._gate_non_linearity)
+        return FastGRNNUnrollFunction.apply(input, self.bias_gate, self.bias_update, self.zeta, self.nu, h_state,
+                                            self.W, self.U, self.W1, self.W2, self.U1, self.U2, self._gate_non_linearity)
 
     def getVars(self):
+        if self._num_W_matrices != 1:
+            return [self.W1, self.W2, self.U1, self.U2, self.bias_gate, self.bias_update, self.zeta, self.nu]
         return [self.W, self.U, self.bias_gate, self.bias_update, self.zeta, self.nu]
 
 class SRNN2(nn.Module):
@@ -1225,10 +1256,10 @@ def backward(ctx, grad_h):
 
 class FastGRNNUnrollFunction(Function):
     @staticmethod
-    def forward(ctx, input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity):
-        outputs = fastgrnn_cuda.forward_unroll(input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity)
+    def forward(ctx, input, bias_gate, bias_update, zeta, nu, old_h, w, u, w1, w2, u1, u2, gate_non_linearity):
+        outputs = fastgrnn_cuda.forward_unroll(input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity, w1, w2, u1, u2)
         hidden_states = outputs[0]
-        variables = [input, hidden_states, zeta, nu, w, u] + outputs[1:] + [old_h]
+        variables = [input, hidden_states, zeta, nu, w, u] + outputs[1:] + [old_h, w1, w2, u1, u2]
         ctx.save_for_backward(*variables)
         ctx.gate_non_linearity = gate_non_linearity
         return hidden_states
@@ -1237,5 +1268,4 @@ def forward(ctx, input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non
     def backward(ctx, grad_h):
         outputs = fastgrnn_cuda.backward_unroll(
             grad_h.contiguous(), *ctx.saved_variables, ctx.gate_non_linearity)
-        d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h = outputs
-        return d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h
+        return tuple(outputs + [None])
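A minimal usage sketch of the updated FastGRNNCUDA constructor and forward pass follows. The import path and tensor shapes here are assumptions; only the keyword names (gate_nonlinearity, update_nonlinearity, wRank, uRank) and the forward(input, h_state) signature come from the diff above.

import torch
from edgeml_pytorch.graph.rnn import FastGRNNCUDA  # import path is an assumption

timesteps, batch, input_size, hidden_size = 100, 16, 32, 64

# Defaults keep the full-rank W and U; passing wRank/uRank instead selects the
# factored parameterization (W2 @ W1, U2 @ U1) introduced in this commit.
rnn = FastGRNNCUDA(input_size, hidden_size,
                   gate_nonlinearity="sigmoid",
                   update_nonlinearity="tanh").cuda()

x = torch.randn(timesteps, batch, input_size, device="cuda")  # [timesteps, batch, features]
h0 = torch.zeros(batch, hidden_size, device="cuda")           # initial hidden state
hidden_states = rnn(x, h0)  # unrolled FastGRNN over all timesteps on the GPU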