Commit 520c10f

fastgrnncuda: add sparsify support
1 parent a89588f

File tree

1 file changed (+51, −2)

  • pytorch/edgeml_pytorch/graph/rnn.py
pytorch/edgeml_pytorch/graph/rnn.py

Lines changed: 51 additions & 2 deletions
@@ -319,8 +319,9 @@ class FastGRNNCUDACell(RNNCell):

     '''
     def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
-                 update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
-        super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_non_linearity, update_nonlinearity, 1, 1, 2, wRank, uRank)
+                 update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, wSparsity=1.0, uSparsity=1.0, name="FastGRNNCUDACell"):
+        super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_nonlinearity, update_nonlinearity,
+                                               1, 1, 2, wRank, uRank, wSparsity, uSparsity)
         if utils.findCUDA() is None:
             raise Exception('FastGRNNCUDA is supported only on GPU devices.')
         NON_LINEARITY = {"sigmoid": 0, "relu": 1, "tanh": 2}
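
With this change, FastGRNNCUDACell accepts the same wSparsity/uSparsity knobs as the other cells in this file: 1.0 (the default) keeps the matrices fully dense, while a smaller value is the fraction of W/U entries allowed to stay nonzero after hard thresholding. A minimal usage sketch follows; the sizes and sparsity levels are illustrative, a CUDA device is required, and the cell(input, state) forward signature is assumed:

    import torch
    from edgeml_pytorch.graph.rnn import FastGRNNCUDACell

    # Illustrative sizes; sparsity 0.2 aims to keep 20% of W/U entries nonzero.
    cell = FastGRNNCUDACell(input_size=32, hidden_size=64,
                            wSparsity=0.2, uSparsity=0.2).to("cuda")

    x = torch.randn(16, 32, device="cuda")   # a batch of 16 input vectors
    h = torch.zeros(16, 64, device="cuda")   # initial hidden state
    h_next = cell(x, h)                       # assumes cell(input, state) forward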
@@ -1166,6 +1167,54 @@ def getVars(self):
         Vars.extend([self.bias_gate, self.bias_update, self.zeta, self.nu])
         return Vars

+    def get_model_size(self):
+        '''
+        Function to get the target model size in bytes
+        '''
+        mats = self.getVars()
+        endW = self._num_W_matrices
+        endU = endW + self._num_U_matrices
+
+        totalnnz = 2  # for zeta and nu
+        for i in range(0, endW):
+            device = mats[i].device
+            totalnnz += utils.countNNZ(mats[i].cpu(), self._wSparsity)
+            mats[i] = mats[i].to(device)
+        for i in range(endW, endU):
+            device = mats[i].device
+            totalnnz += utils.countNNZ(mats[i].cpu(), self._uSparsity)
+            mats[i] = mats[i].to(device)
+        for i in range(endU, len(mats)):
+            device = mats[i].device
+            totalnnz += utils.countNNZ(mats[i].cpu(), False)
+            mats[i] = mats[i].to(device)
+        return totalnnz * 4
+
+    def copy_previous_UW(self):
+        mats = self.getVars()
+        num_mats = self._num_W_matrices + self._num_U_matrices
+        if len(self.oldmats) != num_mats:
+            for i in range(num_mats):
+                self.oldmats.append(torch.FloatTensor())
+        for i in range(num_mats):
+            self.oldmats[i] = mats[i].detach().clone()
+
+    def sparsify(self):
+        mats = self.getVars()
+        endW = self._num_W_matrices
+        endU = endW + self._num_U_matrices
+        for i in range(0, endW):
+            mats[i] = utils.hardThreshold(mats[i], self._wSparsity)
+        for i in range(endW, endU):
+            mats[i] = utils.hardThreshold(mats[i], self._uSparsity)
+        self.copy_previous_UW()
+
+    def sparsifyWithSupport(self):
+        mats = self.getVars()
+        endU = self._num_W_matrices + self._num_U_matrices
+        for i in range(0, endU):
+            mats[i] = utils.supportBasedThreshold(mats[i], self.oldmats[i])
+
 class SRNN2(nn.Module):

     def __init__(self, inputDim, outputDim, hiddenDim0, hiddenDim1, cellType,
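
The byte count returned by get_model_size assumes 4-byte (float32) storage: the target nonzero counts of the W and U factors at their configured sparsities, every entry of the remaining matrices and biases, plus two scalars for zeta and nu. For instance, a 64×32 W factor at wSparsity=0.1 would contribute about 0.1 · 2048 ≈ 205 values, roughly 820 bytes. A hedged stand-in for the utils.countNNZ helper it calls (the real utility may differ in signature and detail):

    import torch

    def count_nnz(A: torch.Tensor, sparsify) -> int:
        # Illustrative stand-in for utils.countNNZ: count nonzero entries for
        # matrices subject to sparsification, and every entry otherwise.
        if sparsify:
            return int(torch.count_nonzero(A).item())
        return A.numel()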
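sparsify and sparsifyWithSupport lean on two thresholding primitives from utils. As used here, hardThreshold keeps the largest-magnitude s-fraction of a matrix's entries and zeroes the rest, while supportBasedThreshold re-projects a matrix onto the sparsity pattern of a reference snapshot (the oldmats saved by copy_previous_UW). Plausible reimplementation sketches, for illustration only:

    import torch

    def hard_threshold(A: torch.Tensor, s: float) -> torch.Tensor:
        # Keep the s-fraction of entries with the largest magnitude, zero the rest.
        if s >= 1.0:
            return A
        k = max(1, int(s * A.numel()))             # number of entries to keep
        # k-th largest magnitude == (numel - k + 1)-th smallest
        th = A.abs().flatten().kthvalue(A.numel() - k + 1).values
        return torch.where(A.abs() >= th, A, torch.zeros_like(A))

    def support_based_threshold(A: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        # Zero every entry of A that is zero in the reference snapshot,
        # i.e. restrict A to the support chosen by the last hard threshold.
        return torch.where(ref != 0, A, torch.zeros_like(A))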
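Together, these methods support the usual three-phase FastGRNN recipe: train dense, then iterate hard thresholding after optimizer updates, then fine-tune with the support frozen. A schematic driver, assuming a model that exposes the methods added in this commit; the phase lengths and batch format are illustrative, not the project's actual trainer:

    def train_sparse(model, loader, optimizer, loss_fn,
                     dense_epochs=100, iht_epochs=100, finetune_epochs=100):
        # Hypothetical three-phase driver: dense warm-up, iterative hard
        # thresholding, then fine-tuning on a frozen sparsity support.
        total = dense_epochs + iht_epochs + finetune_epochs
        for epoch in range(total):
            for x, h0, y in loader:               # illustrative batch format
                optimizer.zero_grad()
                loss = loss_fn(model(x, h0), y)
                loss.backward()
                optimizer.step()
                if dense_epochs <= epoch < dense_epochs + iht_epochs:
                    model.sparsify()              # hard-threshold W/U, snapshot support
                elif epoch >= dense_epochs + iht_epochs:
                    model.sparsifyWithSupport()   # keep weights on the frozen support
        return model.get_model_size()             # bytes, at 4 bytes per value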
