Skip to content

Commit 4bcdf17

Browse files
committed
fastgrnncuda: added low rank support to unrolled version
1 parent 0b70d5e commit 4bcdf17

File tree

3 files changed

+136
-27
lines changed

3 files changed

+136
-27
lines changed

pytorch/edgeml_pytorch/cuda/fastgrnn_cuda.cpp

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@ std::vector<torch::Tensor> fastgrnn_unroll_cuda_forward(
3434
torch::Tensor zeta,
3535
torch::Tensor nu,
3636
torch::Tensor initial_h,
37-
int z_non_linearity);
37+
int z_non_linearity,
38+
torch::Tensor w1,
39+
torch::Tensor w2,
40+
torch::Tensor u1,
41+
torch::Tensor u2);
3842

3943
std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
4044
torch::Tensor grad_h,
@@ -47,7 +51,11 @@ std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
4751
torch::Tensor z,
4852
torch::Tensor h_prime,
4953
torch::Tensor initial_h,
50-
int z_non_linearity);
54+
int z_non_linearity,
55+
torch::Tensor w1,
56+
torch::Tensor w2,
57+
torch::Tensor u1,
58+
torch::Tensor u2);
5159

5260
// Asserts that the tensor expression `x` reports a CUDA device.
// The argument is parenthesized so that compound expressions expand safely,
// and the device is queried directly instead of through the deprecated
// Tensor::type() accessor. The message text is unchanged.
#define CHECK_CUDA(x) AT_ASSERTM((x).device().is_cuda(), #x " must be a CUDA tensor")
5361
// Asserts that the tensor expression `x` is contiguous in memory.
// The argument is parenthesized so that compound expressions expand safely.
#define CHECK_CONTIGUOUS(x) AT_ASSERTM((x).is_contiguous(), #x " must be contiguous")
@@ -108,16 +116,30 @@ std::vector<torch::Tensor> fastgrnn_unroll_forward(
108116
torch::Tensor zeta,
109117
torch::Tensor nu,
110118
torch::Tensor initial_h,
111-
int z_non_linearity) {
119+
int z_non_linearity,
120+
torch::Tensor w1,
121+
torch::Tensor w2,
122+
torch::Tensor u1,
123+
torch::Tensor u2) {
112124
CHECK_INPUT(input);
113-
CHECK_INPUT(w);
114-
CHECK_INPUT(u);
125+
if(w1.size(0) == 0) {
126+
CHECK_INPUT(w);
127+
} else {
128+
CHECK_INPUT(w1);
129+
CHECK_INPUT(w2);
130+
}
131+
if (u1.size(0) == 0) {
132+
CHECK_INPUT(u);
133+
} else {
134+
CHECK_INPUT(u1);
135+
CHECK_INPUT(u2);
136+
}
115137
CHECK_INPUT(bias_z);
116138
CHECK_INPUT(bias_h_prime);
117139
CHECK_INPUT(initial_h);
118140
CHECK_INPUT(zeta);
119141
CHECK_INPUT(nu);
120-
return fastgrnn_unroll_cuda_forward(input, w, u, bias_z, bias_h_prime, zeta, nu, initial_h, z_non_linearity);
142+
return fastgrnn_unroll_cuda_forward(input, w, u, bias_z, bias_h_prime, zeta, nu, initial_h, z_non_linearity, w1, w2, u1, u2);
121143
}
122144

123145
std::vector<torch::Tensor> fastgrnn_unroll_backward(
@@ -131,14 +153,28 @@ std::vector<torch::Tensor> fastgrnn_unroll_backward(
131153
torch::Tensor z,
132154
torch::Tensor h_prime,
133155
torch::Tensor initial_h,
156+
torch::Tensor w1,
157+
torch::Tensor w2,
158+
torch::Tensor u1,
159+
torch::Tensor u2,
134160
int z_non_linearity) {
135161
CHECK_INPUT(grad_h);
136162
CHECK_INPUT(input);
137163
CHECK_INPUT(hidden_states);
138164
CHECK_INPUT(z);
139165
CHECK_INPUT(h_prime);
140-
CHECK_INPUT(w);
141-
CHECK_INPUT(u);
166+
if(w1.size(0) == 0) {
167+
CHECK_INPUT(w);
168+
} else {
169+
CHECK_INPUT(w1);
170+
CHECK_INPUT(w2);
171+
}
172+
if (u1.size(0) == 0) {
173+
CHECK_INPUT(u);
174+
} else {
175+
CHECK_INPUT(u1);
176+
CHECK_INPUT(u2);
177+
}
142178
CHECK_INPUT(zeta);
143179
CHECK_INPUT(nu);
144180
CHECK_INPUT(initial_h);
@@ -154,7 +190,8 @@ std::vector<torch::Tensor> fastgrnn_unroll_backward(
154190
z,
155191
h_prime,
156192
initial_h,
157-
z_non_linearity);
193+
z_non_linearity,
194+
w1, w2, u1, u2);
158195
}
159196

160197

pytorch/edgeml_pytorch/cuda/fastgrnn_cuda_kernel.cu

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,11 @@ std::vector<torch::Tensor> fastgrnn_unroll_cuda_forward(
287287
torch::Tensor zeta,
288288
torch::Tensor nu,
289289
torch::Tensor initial_h,
290-
int z_non_linearity) {
290+
int z_non_linearity,
291+
torch::Tensor w1,
292+
torch::Tensor w2,
293+
torch::Tensor u1,
294+
torch::Tensor u2) {
291295
auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device().type());
292296
const auto timesteps = input.size(0);
293297
const auto batch_size = initial_h.size(0);
@@ -305,9 +309,19 @@ std::vector<torch::Tensor> fastgrnn_unroll_cuda_forward(
305309

306310
const int threads = 1024;
307311
const dim3 blocks((state_size + threads - 1) / threads, batch_size);
312+
bool w_low_rank = w1.size(0) != 0;
313+
bool u_low_rank = u1.size(0) != 0;
314+
if (w_low_rank){
315+
w = torch::mm(w1.transpose(0, 1), w2.transpose(0, 1));
316+
} else {
317+
w = w.transpose(0, 1);
318+
}
319+
if (u_low_rank){
320+
u = torch::mm(u1.transpose(0, 1), u2.transpose(0, 1));
321+
} else {
322+
u = u.transpose(0, 1);
323+
}
308324

309-
w = w.transpose(0, 1);
310-
u = u.transpose(0, 1);
311325
zeta = torch::sigmoid(zeta);
312326
nu = torch::sigmoid(nu);
313327

@@ -372,16 +386,29 @@ std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
372386
torch::Tensor z,
373387
torch::Tensor h_prime,
374388
torch::Tensor initial_h,
375-
int z_non_linearity) {
389+
int z_non_linearity,
390+
torch::Tensor w1,
391+
torch::Tensor w2,
392+
torch::Tensor u1,
393+
torch::Tensor u2) {
376394

377395
auto d_input = torch::zeros_like(input);
378-
auto d_w = torch::zeros_like(w);
379-
auto d_u = torch::zeros_like(u);
380396
auto d_zeta = torch::zeros_like(initial_h);
381397
auto d_nu = torch::zeros_like(initial_h);
382398
auto d_bias_z = torch::zeros_like(initial_h);
383399
auto d_bias_h_prime = torch::zeros_like(initial_h);
384400

401+
bool w_low_rank = w1.size(0) != 0;
402+
bool u_low_rank = u1.size(0) != 0;
403+
if(w_low_rank) {
404+
w = torch::mm(w2, w1);
405+
}
406+
if (u_low_rank) {
407+
u = torch::mm(u2, u1);
408+
}
409+
auto d_w = torch::zeros_like(w);
410+
auto d_u = torch::zeros_like(u);
411+
385412
zeta = torch::sigmoid(zeta);
386413
nu = torch::sigmoid(nu);
387414
auto d_nu_sigmoid = d_sigmoid(nu);
@@ -468,11 +495,26 @@ std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
468495
d_input[t] = torch::mm(d_precomp, w);
469496
d_w = torch::addmm(d_w, d_precomp.transpose(0, 1), input[t]);
470497
d_u = torch::addmm(d_u, d_precomp.transpose(0, 1), prev_h_);
471-
// grad_curr_h = d_old_h;
472498
}
473499
d_bias_z = d_bias_z.sum(0, true);
474500
d_bias_h_prime = d_bias_h_prime.sum(0, true);
475501
d_zeta = (d_zeta.sum(0, true)).sum(1, true);
476502
d_nu = (d_nu.sum(0, true)).sum(1, true);
477-
return {d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h};
478-
}
503+
if (w_low_rank) {
504+
auto d_w1 = torch::mm(w2.transpose(0, 1), d_w);
505+
auto d_w2 = torch::mm(d_w, w1.transpose(0, 1));
506+
d_w = torch::empty(0);
507+
} else {
508+
auto d_w1 = torch::empty(0);
509+
auto d_w2 = torch::empty(0);
510+
}
511+
if(u_low_rank) {
512+
auto d_u1 = torch::mm(u2.transpose(0, 1), d_u);
513+
auto d_u2 = torch::mm(d_u, u1.transpose(0, 1));
514+
d_u = torch::empty(0);
515+
} else {
516+
auto d_u1 = torch::empty(0);
517+
auto d_u2 = torch::empty(0);
518+
}
519+
return {d_input, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h, d_w, d_u, d_w1, d_w2, d_u1, d_u2};
520+
}

pytorch/edgeml_pytorch/graph/rnn.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", zetaIn
329329
self._nuInit = nuInit
330330
self._name = name
331331
self._gate_non_linearity = NON_LINEARITY[gate_non_linearity]
332-
self.W = nn.Parameter(0.1 * torch.randn([input_size, hidden_size]))
332+
self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))
333333
self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
334334

335335
self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
@@ -1065,7 +1065,8 @@ def forward(self, input, hiddenState=None, cellState=None):
10651065

10661066
class FastGRNNCUDA(nn.Module):
10671067
"""Unrolled implementation of the FastGRNNCUDACell"""
1068-
def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
1068+
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
1069+
update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
10691070
super(FastGRNNCUDA, self).__init__()
10701071
if utils.findCUDA() is None:
10711072
raise Exception('FastGRNNCUDA is supported only on GPU devices.')
@@ -1075,7 +1076,34 @@ def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", zetaIn
10751076
self._zetaInit = zetaInit
10761077
self._nuInit = nuInit
10771078
self._name = name
1078-
self._gate_non_linearity = NON_LINEARITY[gate_non_linearity]
1079+
1080+
if wRank is not None:
1081+
self._num_W_matrices += 1
1082+
self._num_weight_matrices[0] = self._num_W_matrices
1083+
if uRank is not None:
1084+
self._num_U_matrices += 1
1085+
self._num_weight_matrices[1] = self._num_U_matrices
1086+
self._name = name
1087+
1088+
if wRank is None:
1089+
self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))
1090+
self.W1 = torch.empty(0)
1091+
self.W2 = torch.empty(0)
1092+
else:
1093+
self.W = torch.empty(0)
1094+
self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size]))
1095+
self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank]))
1096+
1097+
if uRank is None:
1098+
self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
1099+
self.U1 = torch.empty(0)
1100+
self.U2 = torch.empty(0)
1101+
else:
1102+
self.U = torch.empty(0)
1103+
self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size]))
1104+
self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank]))
1105+
1106+
self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]
10791107
self.W = nn.Parameter(0.1 * torch.randn([input_size, hidden_size]))
10801108
self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
10811109

@@ -1086,9 +1114,12 @@ def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", zetaIn
10861114

10871115
def forward(self, input, h_state, cell_state=None):
    """Apply the unrolled FastGRNN CUDA function to `input`, starting from `h_state`.

    `cell_state` is accepted but not used by this method.
    """
    # NOTE(review): the earlier comment gave input as
    # [timesteps, batch, features, state_size] -- confirm against the kernel.
    gate_terms = (self.bias_gate, self.bias_update, self.zeta, self.nu)
    # Dense matrices first, then the low-rank factor pairs.
    weight_terms = (self.W, self.U, self.W1, self.W2, self.U1, self.U2)
    return FastGRNNUnrollFunction.apply(
        input, *gate_terms, h_state, *weight_terms, self._gate_non_linearity)
10901119

10911120
def getVars(self):
    """Return the cell's tensors: weight matrices, then biases, zeta and nu.

    W and U are handled independently. For each, the low-rank factor pair
    (W1, W2 or U1, U2) is returned when the first factor is non-empty;
    otherwise the dense matrix is returned. An empty tensor is treated as
    "not used".
    """
    # Decide per matrix so a mixed configuration (only one of W/U
    # factorized) never returns empty placeholders or drops the real matrix.
    w_vars = [self.W1, self.W2] if self.W1.numel() != 0 else [self.W]
    u_vars = [self.U1, self.U2] if self.U1.numel() != 0 else [self.U]
    return w_vars + u_vars + [self.bias_gate, self.bias_update, self.zeta, self.nu]
10931124

10941125
class SRNN2(nn.Module):
@@ -1225,10 +1256,10 @@ def backward(ctx, grad_h):
12251256

12261257
class FastGRNNUnrollFunction(Function):
12271258
@staticmethod
def forward(ctx, input, bias_gate, bias_update, zeta, nu, old_h, w, u, w1, w2, u1, u2, gate_non_linearity):
    """Run the unrolled CUDA forward pass and stash what backward needs.

    Returns the first element of the extension's output list (the hidden
    states). The remaining outputs are saved on `ctx` together with the
    inputs, and `gate_non_linearity` is kept as a plain attribute.
    """
    results = fastgrnn_cuda.forward_unroll(
        input, w, u, bias_gate, bias_update, zeta, nu, old_h,
        gate_non_linearity, w1, w2, u1, u2)
    hidden_states, extras = results[0], results[1:]
    # The saved order is significant: backward unpacks these positionally
    # into the extension's backward_unroll call.
    ctx.save_for_backward(
        input, hidden_states, zeta, nu, w, u, *extras, old_h, w1, w2, u1, u2)
    ctx.gate_non_linearity = gate_non_linearity
    return hidden_states
@@ -1237,5 +1268,4 @@ def forward(ctx, input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non
12371268
def backward(ctx, grad_h):
    """Compute gradients for every argument of forward via the CUDA extension.

    The saved tensors are passed positionally after the contiguous `grad_h`,
    followed by the stored gate non-linearity. The extension's outputs are
    returned as a tuple with one trailing None appended.
    """
    # saved_tensors replaces the deprecated saved_variables accessor.
    grads = fastgrnn_cuda.backward_unroll(
        grad_h.contiguous(), *ctx.saved_tensors, ctx.gate_non_linearity)
    # The trailing None stands in for the final forward argument,
    # gate_non_linearity.
    return tuple(grads) + (None,)

0 commit comments

Comments
 (0)