@@ -319,8 +319,9 @@ class FastGRNNCUDACell(RNNCell):
    '''
    def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
-                 update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
-        super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_non_linearity, update_nonlinearity, 1, 1, 2, wRank, uRank)
+                 update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, wSparsity=1.0, uSparsity=1.0, name="FastGRNNCUDACell"):
+        super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_nonlinearity, update_nonlinearity,
+                                               1, 1, 2, wRank, uRank, wSparsity, uSparsity)
        if utils.findCUDA() is None:
            raise Exception('FastGRNNCUDA is supported only on GPU devices.')
        NON_LINEARITY = {"sigmoid": 0, "relu": 1, "tanh": 2}
@@ -1166,6 +1167,54 @@ def getVars(self):
        Vars.extend([self.bias_gate, self.bias_update, self.zeta, self.nu])
        return Vars

+    def get_model_size(self):
+        '''
+        Function to get the model size (in bytes) targeted after sparsification.
+        '''
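+        # Count the nonzeros that survive the sparsity targets: W factors at
+        # wSparsity, U factors at uSparsity, and the remaining gate/bias
+        # parameters (plus zeta and nu) kept dense; the returned size assumes
+        # 4 bytes (float32) per parameter.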
+        mats = self.getVars()
+        endW = self._num_W_matrices
+        endU = endW + self._num_U_matrices
+
+        totalnnz = 2  # For Zeta and Nu
+        for i in range(0, endW):
+            device = mats[i].device
+            totalnnz += utils.countNNZ(mats[i].cpu(), self._wSparsity)
+            mats[i].to(device)
+        for i in range(endW, endU):
+            device = mats[i].device
+            totalnnz += utils.countNNZ(mats[i].cpu(), self._uSparsity)
+            mats[i].to(device)
+        for i in range(endU, len(mats)):
+            device = mats[i].device
+            totalnnz += utils.countNNZ(mats[i].cpu(), False)
+            mats[i].to(device)
+        return totalnnz * 4
+
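+    # Snapshot the current W/U factors; sparsifyWithSupport() reuses this
+    # snapshot as the fixed support (nonzero pattern) for later thresholding.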
+    def copy_previous_UW(self):
+        mats = self.getVars()
+        num_mats = self._num_W_matrices + self._num_U_matrices
+        if len(self.oldmats) != num_mats:
+            for i in range(num_mats):
+                self.oldmats.append(torch.FloatTensor())
+        for i in range(num_mats):
+            self.oldmats[i] = torch.FloatTensor(mats[i].detach().clone().to(mats[i].device))
+
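+    # Hard-threshold the W factors (wSparsity) and U factors (uSparsity), then
+    # record the resulting support via copy_previous_UW().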
+    def sparsify(self):
+        mats = self.getVars()
+        endW = self._num_W_matrices
+        endU = endW + self._num_U_matrices
+        for i in range(0, endW):
+            mats[i] = utils.hardThreshold(mats[i], self._wSparsity)
+        for i in range(endW, endU):
+            mats[i] = utils.hardThreshold(mats[i], self._uSparsity)
+        self.copy_previous_UW()
+
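+    # Re-threshold the W/U factors so that only entries inside the support
+    # recorded in self.oldmats (see copy_previous_UW) are retained.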
+    def sparsifyWithSupport(self):
+        mats = self.getVars()
+        endU = self._num_W_matrices + self._num_U_matrices
+        for i in range(0, endU):
+            mats[i] = utils.supportBasedThreshold(mats[i], self.oldmats[i])
+
class SRNN2(nn.Module):

    def __init__(self, inputDim, outputDim, hiddenDim0, hiddenDim1, cellType,