
Commit 23f6e0c

fastgrnncuda: add low rank support to cell
1 parent cb1e26f

File tree

3 files changed: +188, −72 lines

pytorch/edgeml_pytorch/cuda/fastgrnn_cuda.cpp

Lines changed: 80 additions & 44 deletions
@@ -3,27 +3,35 @@
 #include <vector>

 std::vector<torch::Tensor> fastgrnn_cuda_forward(
-    torch::Tensor input,
-    torch::Tensor W,
-    torch::Tensor U,
-    torch::Tensor bias_gate,
-    torch::Tensor bias_update,
-    torch::Tensor zeta,
-    torch::Tensor nu,
-    torch::Tensor old_h,
-    int z_non_linearity);
+    torch::Tensor input,
+    torch::Tensor w,
+    torch::Tensor u,
+    torch::Tensor bias_gate,
+    torch::Tensor bias_update,
+    torch::Tensor zeta,
+    torch::Tensor nu,
+    torch::Tensor old_h,
+    int z_non_linearity,
+    torch::Tensor w1,
+    torch::Tensor w2,
+    torch::Tensor u1,
+    torch::Tensor u2);

 std::vector<torch::Tensor> fastgrnn_cuda_backward(
-    torch::Tensor grad_h,
-    torch::Tensor input,
-    torch::Tensor old_h,
-    torch::Tensor zeta,
-    torch::Tensor nu,
-    torch::Tensor W,
-    torch::Tensor U,
-    int z_non_linearity,
-    torch::Tensor z,
-    torch::Tensor h_prime);
+    torch::Tensor grad_h,
+    torch::Tensor input,
+    torch::Tensor old_h,
+    torch::Tensor zeta,
+    torch::Tensor nu,
+    torch::Tensor w,
+    torch::Tensor u,
+    int z_non_linearity,
+    torch::Tensor z,
+    torch::Tensor h_prime,
+    torch::Tensor w1,
+    torch::Tensor w2,
+    torch::Tensor u1,
+    torch::Tensor u2);

 std::vector<torch::Tensor> fastgrnn_unroll_cuda_forward(
     torch::Tensor input,
@@ -62,49 +70,77 @@ std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

 std::vector<torch::Tensor> fastgrnn_forward(
-    torch::Tensor input,
-    torch::Tensor W,
-    torch::Tensor U,
-    torch::Tensor bias_gate,
-    torch::Tensor bias_update,
-    torch::Tensor zeta,
-    torch::Tensor nu,
-    torch::Tensor old_h,
-    int z_non_linearity) {
+    torch::Tensor input,
+    torch::Tensor w,
+    torch::Tensor u,
+    torch::Tensor bias_gate,
+    torch::Tensor bias_update,
+    torch::Tensor zeta,
+    torch::Tensor nu,
+    torch::Tensor old_h,
+    int z_non_linearity,
+    torch::Tensor w1,
+    torch::Tensor w2,
+    torch::Tensor u1,
+    torch::Tensor u2) {
   CHECK_INPUT(input);
-  CHECK_INPUT(W);
-  CHECK_INPUT(U);
+  if(w1.size(0) == 0) {
+    CHECK_INPUT(w);
+  } else {
+    CHECK_INPUT(w1);
+    CHECK_INPUT(w2);
+  }
+  if (u1.size(0) == 0) {
+    CHECK_INPUT(u);
+  } else {
+    CHECK_INPUT(u1);
+    CHECK_INPUT(u2);
+  }
   CHECK_INPUT(bias_gate);
   CHECK_INPUT(bias_update);
   CHECK_INPUT(zeta);
   CHECK_INPUT(nu);
   CHECK_INPUT(old_h);

-  return fastgrnn_cuda_forward(input, W, U, bias_gate, bias_update, zeta, nu, old_h, z_non_linearity);
+  return fastgrnn_cuda_forward(input, w, u, bias_gate, bias_update, zeta, nu, old_h, z_non_linearity, w1, w2, u1, u2);
 }

 std::vector<torch::Tensor> fastgrnn_backward(
-    torch::Tensor grad_h,
-    torch::Tensor input,
-    torch::Tensor old_h,
-    torch::Tensor zeta,
-    torch::Tensor nu,
-    torch::Tensor W,
-    torch::Tensor U,
-    torch::Tensor z,
-    torch::Tensor h_prime,
-    int z_non_linearity) {
+    torch::Tensor grad_h,
+    torch::Tensor input,
+    torch::Tensor old_h,
+    torch::Tensor zeta,
+    torch::Tensor nu,
+    torch::Tensor w,
+    torch::Tensor u,
+    torch::Tensor z,
+    torch::Tensor h_prime,
+    torch::Tensor w1,
+    torch::Tensor w2,
+    torch::Tensor u1,
+    torch::Tensor u2,
+    int z_non_linearity) {
   CHECK_INPUT(grad_h);
   CHECK_INPUT(input);
   CHECK_INPUT(old_h);
   CHECK_INPUT(zeta);
   CHECK_INPUT(nu);
   CHECK_INPUT(z);
   CHECK_INPUT(h_prime);
-  CHECK_INPUT(W);
-  CHECK_INPUT(U);
+  if(w1.size(0) == 0) {
+    CHECK_INPUT(w);
+  } else {
+    CHECK_INPUT(w1);
+    CHECK_INPUT(w2);
+  }
+  if (u1.size(0) == 0) {
+    CHECK_INPUT(u);
+  } else {
+    CHECK_INPUT(u1);
+    CHECK_INPUT(u2);
+  }

-  return fastgrnn_cuda_backward(grad_h, input, old_h, zeta, nu, W, U, z_non_linearity, z, h_prime);
+  return fastgrnn_cuda_backward(grad_h, input, old_h, zeta, nu, w, u, z_non_linearity, z, h_prime, w1, w2, u1, u2);
 }

 std::vector<torch::Tensor> fastgrnn_unroll_forward(
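The new w1/w2/u1/u2 arguments carry the low-rank factors; callers that keep a full-rank W or U pass empty tensors in their place, and the wrapper checks `w1.size(0) == 0` to decide which set to validate. A minimal Python sketch of that convention (the helper and sizes below are illustrative, not part of the extension API):

import torch

def pick_w(w, w1, w2):
    # Mirrors the C++ dispatch: w1.size(0) == 0 signals the full-rank path,
    # otherwise the dense matrix is reconstructed as w2 @ w1.
    if w1.shape[0] == 0:
        return w
    return w2 @ w1

hidden, inp, rank = 64, 16, 8
full_w = torch.randn(hidden, inp)
w1, w2 = torch.randn(rank, inp), torch.randn(hidden, rank)

assert pick_w(full_w, torch.empty(0), torch.empty(0)).shape == (hidden, inp)
assert pick_w(torch.empty(0), w1, w2).shape == (hidden, inp)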

pytorch/edgeml_pytorch/cuda/fastgrnn_cuda_kernel.cu

Lines changed: 44 additions & 5 deletions
@@ -129,8 +129,21 @@ std::vector<torch::Tensor> fastgrnn_cuda_forward(
     torch::Tensor zeta,
     torch::Tensor nu,
     torch::Tensor old_h,
-    int z_non_linearity) {
-
+    int z_non_linearity,
+    torch::Tensor w1,
+    torch::Tensor w2,
+    torch::Tensor u1,
+    torch::Tensor u2) {
+
+  bool w_low_rank = w1.size(0) != 0;
+  bool u_low_rank = u1.size(0) != 0;
+  if (w_low_rank){
+    w = torch::mm(w2, w1);
+  }
+  if (u_low_rank){
+    u = torch::mm(u2, u1);
+  }
+
   auto pre_comp = torch::addmm(torch::mm(input, w.transpose(0, 1)), old_h, u.transpose(0, 1));
   nu = torch::sigmoid(nu);
   zeta = torch::sigmoid(zeta);
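When the low-rank path is active, the forward wrapper first rebuilds the dense matrices from their factors and then computes the usual FastGRNN pre-activation. Roughly the same computation in PyTorch, with illustrative shapes:

import torch

batch, inp, hidden, rank = 4, 16, 64, 8
x = torch.randn(batch, inp)
h_prev = torch.randn(batch, hidden)
w1, w2 = torch.randn(rank, inp), torch.randn(hidden, rank)
u1, u2 = torch.randn(rank, hidden), torch.randn(hidden, rank)

# Low-rank path: rebuild the dense matrices, then the usual pre-activation.
w = w2 @ w1            # [hidden, inp]
u = u2 @ u1            # [hidden, hidden]
pre_comp = x @ w.t() + h_prev @ u.t()   # matches torch::addmm(mm(x, w^T), h_prev, u^T)
print(pre_comp.shape)  # torch.Size([4, 64])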
@@ -194,13 +207,30 @@ std::vector<torch::Tensor> fastgrnn_cuda_backward(
     torch::Tensor u,
     int z_non_linearity,
     torch::Tensor z,
-    torch::Tensor h_prime) {
+    torch::Tensor h_prime,
+    torch::Tensor w1,
+    torch::Tensor w2,
+    torch::Tensor u1,
+    torch::Tensor u2) {
   auto d_precomp = torch::zeros_like(old_h);
   auto d_bias_z = torch::zeros_like(old_h);
   auto d_bias_h_prime = torch::zeros_like(old_h);
   auto d_nu = torch::zeros_like(old_h);
   auto d_zeta = torch::zeros_like(old_h);
   auto d_old_h = torch::zeros_like(old_h);
+  auto d_w1 = torch::empty(0);
+  auto d_w2 = torch::empty(0);
+  auto d_u1 = torch::empty(0);
+  auto d_u2 = torch::empty(0);
+
+  bool w_low_rank = w1.size(0) != 0;
+  bool u_low_rank = u1.size(0) != 0;
+  if(w_low_rank) {
+    w = torch::mm(w2, w1);
+  }
+  if (u_low_rank) {
+    u = torch::mm(u2, u1);
+  }
   zeta = torch::sigmoid(zeta);
   nu = torch::sigmoid(nu);
   auto d_nu_sigmoid = d_sigmoid(nu);
@@ -274,8 +304,17 @@ std::vector<torch::Tensor> fastgrnn_cuda_backward(
   d_bias_h_prime = d_bias_h_prime.sum(0, true);
   d_zeta = (d_zeta.sum(0, true)).sum(1, true);
   d_nu = (d_nu.sum(0, true)).sum(1, true);
-
-  return {d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h};
+  if (w_low_rank) {
+    d_w1 = torch::mm(w2.transpose(0, 1), d_w);
+    d_w2 = torch::mm(d_w, w1.transpose(0, 1));
+    d_w = torch::empty(0);
+  }
+  if(u_low_rank) {
+    d_u1 = torch::mm(u2.transpose(0, 1), d_u);
+    d_u2 = torch::mm(d_u, u1.transpose(0, 1));
+    d_u = torch::empty(0);
+  }
+  return {d_input, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h, d_w, d_u, d_w1, d_w2, d_u1, d_u2};
 }

 std::vector<torch::Tensor> fastgrnn_unroll_cuda_forward(
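For a matrix parameterized as W = W2 @ W1, the chain rule gives dL/dW1 = W2^T (dL/dW) and dL/dW2 = (dL/dW) W1^T, which is what the new backward branch computes before replacing the dense gradient with an empty tensor. A small self-check of those formulas against autograd, with illustrative sizes:

import torch

hidden, inp, rank = 6, 4, 2
w1 = torch.randn(rank, inp, requires_grad=True)
w2 = torch.randn(hidden, rank, requires_grad=True)
x = torch.randn(3, inp)

# Forward through the reconstructed dense matrix, as the kernel wrapper does.
w = w2 @ w1
loss = (x @ w.t()).sum()
loss.backward()

# Manual factor gradients from the dense gradient d_w, as in the backward branch.
d_w = torch.ones(3, hidden).t() @ x     # dL/dW for this loss: grad_out^T @ x
assert torch.allclose(w1.grad, w2.detach().t() @ d_w, atol=1e-5)
assert torch.allclose(w2.grad, d_w @ w1.detach().t(), atol=1e-5)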

pytorch/edgeml_pytorch/graph/rnn.py

Lines changed: 64 additions & 23 deletions
@@ -318,30 +318,51 @@ class FastGRNNCUDACell(RNNCell):
     h_t = z_t*h_{t-1} + (sigmoid(zeta)(1-z_t) + sigmoid(nu))*h_t^

     '''
-    def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
-        super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_non_linearity, "tanh", 1, 1, 2)
+    def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
+                 update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
+        super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_nonlinearity, update_nonlinearity, 1, 1, 2, wRank, uRank)
         if utils.findCUDA() is None:
-            raise Exception('FastGRNNCUDACell is supported only on GPU devices.')
+            raise Exception('FastGRNNCUDA is supported only on GPU devices.')
         NON_LINEARITY = {"sigmoid": 0, "relu": 1, "tanh": 2}
         self._input_size = input_size
         self._hidden_size = hidden_size
         self._zetaInit = zetaInit
         self._nuInit = nuInit
         self._name = name
-        self._gate_non_linearity = NON_LINEARITY[gate_non_linearity]
-        self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))
-        self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
+
+        if wRank is not None:
+            self._num_W_matrices += 1
+            self._num_weight_matrices[0] = self._num_W_matrices
+        if uRank is not None:
+            self._num_U_matrices += 1
+            self._num_weight_matrices[1] = self._num_U_matrices
+        self._name = name
+
+        if wRank is None:
+            self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))
+            self.W1 = torch.empty(0)
+            self.W2 = torch.empty(0)
+        else:
+            self.W = torch.empty(0)
+            self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size]))
+            self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank]))
+
+        if uRank is None:
+            self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
+            self.U1 = torch.empty(0)
+            self.U2 = torch.empty(0)
+        else:
+            self.U = torch.empty(0)
+            self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size]))
+            self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank]))
+
+        self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]

         self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
         self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
         self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
         self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))

-    def reset_parameters(self):
-        stdv = 1.0 / math.sqrt(self.state_size)
-        for weight in self.parameters():
-            weight.data.uniform_(-stdv, +stdv)
-
     @property
     def name(self):
         return self._name
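With the new constructor arguments, a low-rank cell might be instantiated roughly as below. This assumes the fastgrnn_cuda extension is built and a CUDA device is available; the batch size and ranks are illustrative, not taken from the commit:

import torch
from edgeml_pytorch.graph.rnn import FastGRNNCUDACell

# Illustrative sizes; wRank/uRank select the low-rank factor path.
cell = FastGRNNCUDACell(input_size=32, hidden_size=128, wRank=16, uRank=16).cuda()

x = torch.randn(64, 32, device="cuda")    # one timestep, batch of 64
h = torch.zeros(64, 128, device="cuda")
h_next = cell(x, h)                       # dispatches to FastGRNNFunction / the CUDA kernel
print(h_next.shape)                       # torch.Size([64, 128])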
@@ -352,10 +373,23 @@ def cellType(self):

     def forward(self, input, state):
         # Calls the custom autograd function which invokes the CUDA implementation
-        return FastGRNNFunction.apply(input, self.W, self.U, self.bias_gate, self.bias_update, self.zeta, self.nu, state, self._gate_non_linearity)
+        return FastGRNNFunction.apply(input, self.bias_gate, self.bias_update, self.zeta, self.nu, state,
+            self.W, self.U, self.W1, self.W2, self.U1, self.U2, self._gate_non_linearity)

     def getVars(self):
-        return [self.W, self.U, self.bias_gate, self.bias_update, self.zeta, self.nu]
+        Vars = []
+        if self._num_W_matrices == 1:
+            Vars.append(self.W)
+        else:
+            Vars.extend([self.W1, self.W2])
+
+        if self._num_U_matrices == 1:
+            Vars.append(self.U)
+        else:
+            Vars.extend([self.U1, self.U2])
+
+        Vars.extend([self.bias_gate, self.bias_update, self.zeta, self.nu])
+        return Vars

 class FastRNNCell(RNNCell):
     '''
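The payoff of wRank/uRank is a smaller parameter count: a full-rank W stores hidden_size * input_size values, while the factored pair W1, W2 stores wRank * (input_size + hidden_size). A quick illustrative calculation (sizes chosen arbitrarily):

def fastgrnn_w_params(input_size, hidden_size, wRank=None):
    # Full-rank W vs. the factored pair W1 [wRank, input] and W2 [hidden, wRank].
    if wRank is None:
        return hidden_size * input_size
    return wRank * (input_size + hidden_size)

print(fastgrnn_w_params(32, 128))            # 4096 parameters
print(fastgrnn_w_params(32, 128, wRank=16))  # 2560 parameters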
@@ -1104,8 +1138,6 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
             self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank]))

         self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]
-        self.W = nn.Parameter(0.1 * torch.randn([input_size, hidden_size]))
-        self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))

         self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
         self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
@@ -1118,9 +1150,19 @@ def forward(self, input, h_state, cell_state=None):
             self.W, self.U, self.W1, self.W2, self.U1, self.U2, self._gate_non_linearity)

     def getVars(self):
-        if self._num_W_matrices != 1:
-            return [self.W1, self.W2, self.U1, self.U2, self.bias_gate, self.bias_update, self.zeta, self.nu]
-        return [self.W, self.U, self.bias_gate, self.bias_update, self.zeta, self.nu]
+        Vars = []
+        if self._num_W_matrices == 1:
+            Vars.append(self.W)
+        else:
+            Vars.extend([self.W1, self.W2])
+
+        if self._num_U_matrices == 1:
+            Vars.append(self.U)
+        else:
+            Vars.extend([self.U1, self.U2])
+
+        Vars.extend([self.bias_gate, self.bias_update, self.zeta, self.nu])
+        return Vars

 class SRNN2(nn.Module):

@@ -1239,10 +1281,10 @@ def forward(self, x, brickSize):

 class FastGRNNFunction(Function):
     @staticmethod
-    def forward(ctx, input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity):
-        outputs = fastgrnn_cuda.forward(input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity)
+    def forward(ctx, input, bias_gate, bias_update, zeta, nu, old_h, w, u, w1, w2, u1, u2, gate_non_linearity):
+        outputs = fastgrnn_cuda.forward(input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity, w1, w2, u1, u2)
         new_h = outputs[0]
-        variables = [input, old_h, zeta, nu, w, u] + outputs[1:]
+        variables = [input, old_h, zeta, nu, w, u] + outputs[1:] + [w1, w2, u1, u2]
         ctx.save_for_backward(*variables)
         ctx.non_linearity = gate_non_linearity
         return new_h
@@ -1251,8 +1293,7 @@ def forward(ctx, input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity):
     def backward(ctx, grad_h):
         outputs = fastgrnn_cuda.backward(
            grad_h.contiguous(), *ctx.saved_variables, ctx.non_linearity)
-        d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h = outputs
-        return d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h, None
+        return tuple(outputs + [None])

 class FastGRNNUnrollFunction(Function):
     @staticmethod
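torch.autograd.Function expects backward to return one gradient per argument of forward (None for non-tensor arguments such as gate_non_linearity), which is why the CUDA backward now emits its gradients in the order of the new forward signature and the Python side can simply append None. A toy Function showing the same contract and the same factor-gradient formulas (a sketch, not the real extension):

import torch
from torch.autograd import Function

class LowRankMatVec(Function):
    """Toy y = scale * (w2 @ w1) @ x, mirroring the one-gradient-per-input contract."""

    @staticmethod
    def forward(ctx, x, w1, w2, scale):
        ctx.save_for_backward(x, w1, w2)
        ctx.scale = scale
        return scale * (w2 @ w1) @ x

    @staticmethod
    def backward(ctx, grad_y):
        x, w1, w2 = ctx.saved_tensors
        grad_y = ctx.scale * grad_y
        d_w = torch.outer(grad_y, x)   # gradient w.r.t. the dense product w2 @ w1
        d_x = (w2 @ w1).t() @ grad_y
        d_w1 = w2.t() @ d_w            # factor gradients, same form as the CUDA backward
        d_w2 = d_w @ w1.t()
        # One return value per forward argument: x, w1, w2, scale (non-tensor -> None).
        return d_x, d_w1, d_w2, None

x = torch.randn(4)
w1 = torch.randn(2, 4, requires_grad=True)
w2 = torch.randn(3, 2, requires_grad=True)
LowRankMatVec.apply(x, w1, w2, 2.0).sum().backward()
print(w1.grad.shape, w2.grad.shape)  # torch.Size([2, 4]) torch.Size([3, 2])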
