
Commit 0b70d5e

FastGRNNCUDA: added unrolled implementation

1 parent e4ce97f · commit 0b70d5e

File tree: 3 files changed, +359 −0 lines changed

pytorch/edgeml_pytorch/cuda/fastgrnn_cuda.cpp

Lines changed: 85 additions & 0 deletions
@@ -25,6 +25,30 @@ std::vector<torch::Tensor> fastgrnn_cuda_backward(
    torch::Tensor z,
    torch::Tensor h_prime);

std::vector<torch::Tensor> fastgrnn_unroll_cuda_forward(
    torch::Tensor input,
    torch::Tensor w,
    torch::Tensor u,
    torch::Tensor bias_z,
    torch::Tensor bias_h_prime,
    torch::Tensor zeta,
    torch::Tensor nu,
    torch::Tensor initial_h,
    int z_non_linearity);

std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
    torch::Tensor grad_h,
    torch::Tensor input,
    torch::Tensor hidden_states,
    torch::Tensor zeta,
    torch::Tensor nu,
    torch::Tensor w,
    torch::Tensor u,
    torch::Tensor z,
    torch::Tensor h_prime,
    torch::Tensor initial_h,
    int z_non_linearity);

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
@@ -75,7 +99,68 @@ std::vector<torch::Tensor> fastgrnn_backward(
  return fastgrnn_cuda_backward(grad_h, input, old_h, zeta, nu, W, U, z_non_linearity, z, h_prime);
}

std::vector<torch::Tensor> fastgrnn_unroll_forward(
    torch::Tensor input,
    torch::Tensor w,
    torch::Tensor u,
    torch::Tensor bias_z,
    torch::Tensor bias_h_prime,
    torch::Tensor zeta,
    torch::Tensor nu,
    torch::Tensor initial_h,
    int z_non_linearity) {
  CHECK_INPUT(input);
  CHECK_INPUT(w);
  CHECK_INPUT(u);
  CHECK_INPUT(bias_z);
  CHECK_INPUT(bias_h_prime);
  CHECK_INPUT(initial_h);
  CHECK_INPUT(zeta);
  CHECK_INPUT(nu);
  return fastgrnn_unroll_cuda_forward(input, w, u, bias_z, bias_h_prime, zeta, nu, initial_h, z_non_linearity);
}

std::vector<torch::Tensor> fastgrnn_unroll_backward(
    torch::Tensor grad_h,
    torch::Tensor input,
    torch::Tensor hidden_states,
    torch::Tensor zeta,
    torch::Tensor nu,
    torch::Tensor w,
    torch::Tensor u,
    torch::Tensor z,
    torch::Tensor h_prime,
    torch::Tensor initial_h,
    int z_non_linearity) {
  CHECK_INPUT(grad_h);
  CHECK_INPUT(input);
  CHECK_INPUT(hidden_states);
  CHECK_INPUT(z);
  CHECK_INPUT(h_prime);
  CHECK_INPUT(w);
  CHECK_INPUT(u);
  CHECK_INPUT(zeta);
  CHECK_INPUT(nu);
  CHECK_INPUT(initial_h);

  return fastgrnn_unroll_cuda_backward(
    grad_h,
    input,
    hidden_states,
    zeta,
    nu,
    w,
    u,
    z,
    h_prime,
    initial_h,
    z_non_linearity);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &fastgrnn_forward, "FastGRNN forward (CUDA)");
  m.def("backward", &fastgrnn_backward, "FastGRNN backward (CUDA)");
  m.def("forward_unroll", &fastgrnn_unroll_forward, "FastGRNN Unrolled forward (CUDA)");
  m.def("backward_unroll", &fastgrnn_unroll_backward, "FastGRNN Unrolled backward (CUDA)");
}
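These bindings compile as a standard PyTorch C++/CUDA extension. As a minimal sketch (not part of the commit), the module can be JIT-compiled with torch.utils.cpp_extension; the source paths below are assumptions and should point at pytorch/edgeml_pytorch/cuda/ in the repository:

from torch.utils.cpp_extension import load

# Hypothetical paths: adjust to wherever the two source files live.
fastgrnn_cuda = load(
    name="fastgrnn_cuda",
    sources=["fastgrnn_cuda.cpp", "fastgrnn_cuda_kernel.cu"],
    verbose=True,
)

# Entry points exposed by PYBIND11_MODULE above:
#   fastgrnn_cuda.forward / fastgrnn_cuda.backward                 (single step)
#   fastgrnn_cuda.forward_unroll / fastgrnn_cuda.backward_unroll   (full sequence)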

pytorch/edgeml_pytorch/cuda/fastgrnn_cuda_kernel.cu

Lines changed: 230 additions & 0 deletions
@@ -87,6 +87,37 @@ __global__ void fastgrnn_cuda_backward_kernel(
    d_nu[n][c] = h_prime[n][c] * grad_h[n][c] * d_nu_sigmoid[0][0];
  }
}

template <typename scalar_t, scalar_t (*d_non_linearity) (scalar_t)>
__global__ void fastgrnn_unroll_cuda_backward_kernel(
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_precomp,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_h,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_z,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_h_prime,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> zeta,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta_sigmoid,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu_sigmoid,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h) {
  const int n = blockIdx.y;
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c < old_h.size(1)) {
    d_old_h[n][c] = z[n][c] * grad_h[n][c];
    scalar_t temp_bias_h_prime = (zeta[0][0] * (1.0 - z[n][c]) + nu[0][0]) * d_tanh(h_prime[n][c]) * grad_h[n][c];
    scalar_t temp_bias_z = (old_h[n][c] - zeta[0][0] * h_prime[n][c]) * d_non_linearity(z[n][c]) * grad_h[n][c];
    d_bias_h_prime[n][c] += temp_bias_h_prime;
    d_bias_z[n][c] += temp_bias_z;
    d_precomp[n][c] = temp_bias_z + temp_bias_h_prime;
    d_zeta[n][c] += (1.0 - z[n][c]) * h_prime[n][c] * grad_h[n][c] * d_zeta_sigmoid[0][0];
    d_nu[n][c] += h_prime[n][c] * grad_h[n][c] * d_nu_sigmoid[0][0];
  }
}

} // namespace

std::vector<torch::Tensor> fastgrnn_cuda_forward(
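For context, the update both kernels differentiate is the FastGRNN recurrence; ζ and ν are stored pre-sigmoid, which is why the host functions below apply torch::sigmoid before launching. A reference statement of the cell (not part of the diff):

$$z_t = \sigma_g(W x_t + U h_{t-1} + b_z), \qquad \tilde{h}_t = \tanh(W x_t + U h_{t-1} + b_{\tilde{h}}), \qquad h_t = \big(\sigma(\zeta)(1 - z_t) + \sigma(\nu)\big) \odot \tilde{h}_t + z_t \odot h_{t-1}$$

Here σ_g is the gate non-linearity selected by z_non_linearity. Compared with the single-step backward kernel above it, the unrolled kernel accumulates d_bias_z, d_bias_h_prime, d_zeta, and d_nu with += across timesteps, and writes d_old_h so the next (earlier) iteration can fold it into grad_h[t].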
@@ -246,3 +277,202 @@ std::vector<torch::Tensor> fastgrnn_cuda_backward(

  return {d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h};
}

std::vector<torch::Tensor> fastgrnn_unroll_cuda_forward(
    torch::Tensor input,
    torch::Tensor w,
    torch::Tensor u,
    torch::Tensor bias_z,
    torch::Tensor bias_h_prime,
    torch::Tensor zeta,
    torch::Tensor nu,
    torch::Tensor initial_h,
    int z_non_linearity) {
  auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device().type());
  const auto timesteps = input.size(0);
  const auto batch_size = initial_h.size(0);
  const auto state_size = initial_h.size(1);

  auto hidden_states = torch::zeros({timesteps, batch_size, state_size}, options);
  auto z_s = torch::zeros_like(hidden_states);
  auto h_prime_s = torch::zeros_like(hidden_states);

  auto prev_h = initial_h;
  auto new_h = torch::zeros_like(prev_h);
  auto z = torch::zeros_like(prev_h);
  auto h_prime = torch::zeros_like(prev_h);
  auto pre_comp = torch::zeros_like(prev_h);

  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);

  w = w.transpose(0, 1);
  u = u.transpose(0, 1);
  zeta = torch::sigmoid(zeta);
  nu = torch::sigmoid(nu);

  for (int t = 0; t < timesteps; t++) {
    pre_comp = torch::addmm(torch::mm(input[t], w), prev_h, u);

    if (z_non_linearity == 0)
      AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] {
        fastgrnn_cuda_forward_kernel<scalar_t, sigmoid><<<blocks, threads>>>(
            new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            prev_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
      }));
    else if (z_non_linearity == 1)
      AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] {
        fastgrnn_cuda_forward_kernel<scalar_t, relu><<<blocks, threads>>>(
            new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            prev_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
      }));
    else if (z_non_linearity == 2)
      AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] {
        fastgrnn_cuda_forward_kernel<scalar_t, tanh><<<blocks, threads>>>(
            new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            prev_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
      }));

    hidden_states[t] = new_h;
    z_s[t] = z;
    h_prime_s[t] = h_prime;
    prev_h = new_h;
  }
  return {hidden_states, z_s, h_prime_s};
}

std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
    torch::Tensor grad_h,
    torch::Tensor input,
    torch::Tensor hidden_states,
    torch::Tensor zeta,
    torch::Tensor nu,
    torch::Tensor w,
    torch::Tensor u,
    torch::Tensor z,
    torch::Tensor h_prime,
    torch::Tensor initial_h,
    int z_non_linearity) {

  auto d_input = torch::zeros_like(input);
  auto d_w = torch::zeros_like(w);
  auto d_u = torch::zeros_like(u);
  auto d_zeta = torch::zeros_like(initial_h);
  auto d_nu = torch::zeros_like(initial_h);
  auto d_bias_z = torch::zeros_like(initial_h);
  auto d_bias_h_prime = torch::zeros_like(initial_h);

  zeta = torch::sigmoid(zeta);
  nu = torch::sigmoid(nu);
  auto d_nu_sigmoid = d_sigmoid(nu);
  auto d_zeta_sigmoid = d_sigmoid(zeta);

  auto grad_curr_h = torch::zeros_like(initial_h);
  auto d_precomp = torch::zeros_like(initial_h);
  auto d_old_h = torch::zeros_like(initial_h);
  auto prev_h_ = hidden_states[0];
  auto z_t_ = torch::zeros_like(initial_h);
  auto h_prime_t_ = torch::zeros_like(initial_h);

  const auto batch_size = hidden_states.size(1);
  const auto state_size = hidden_states.size(2);

  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);
  for (auto t = hidden_states.size(0) - 1; t >= 0; t--) {
    grad_curr_h = torch::add(grad_h[t], d_old_h);
    z_t_ = z[t];
    h_prime_t_ = h_prime[t];

    if (t == 0)
      prev_h_ = initial_h;
    else
      prev_h_ = hidden_states[t-1];

    if (z_non_linearity == 0)
      AT_DISPATCH_FLOATING_TYPES(z_t_.type(), "fastgrnn_backward_cuda", ([&] {
        fastgrnn_unroll_cuda_backward_kernel<scalar_t, d_sigmoid><<<blocks, threads>>>(
            d_precomp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            grad_curr_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            z_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            h_prime_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_zeta_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_nu_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            prev_h_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
      }));
    else if (z_non_linearity == 1)
      AT_DISPATCH_FLOATING_TYPES(z_t_.type(), "fastgrnn_backward_cuda", ([&] {
        fastgrnn_unroll_cuda_backward_kernel<scalar_t, d_relu><<<blocks, threads>>>(
            d_precomp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            grad_curr_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            z_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            h_prime_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_zeta_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_nu_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            prev_h_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
      }));
    else if (z_non_linearity == 2)
      // tanh gate: dispatch the tanh derivative for d_non_linearity
      AT_DISPATCH_FLOATING_TYPES(z_t_.type(), "fastgrnn_backward_cuda", ([&] {
        fastgrnn_unroll_cuda_backward_kernel<scalar_t, d_tanh><<<blocks, threads>>>(
            d_precomp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            grad_curr_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            z_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            h_prime_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_zeta_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            d_nu_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            prev_h_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
      }));

    d_old_h = torch::addmm(d_old_h, d_precomp, u);
    d_input[t] = torch::mm(d_precomp, w);
    d_w = torch::addmm(d_w, d_precomp.transpose(0, 1), input[t]);
    d_u = torch::addmm(d_u, d_precomp.transpose(0, 1), prev_h_);
  }
  d_bias_z = d_bias_z.sum(0, true);
  d_bias_h_prime = d_bias_h_prime.sum(0, true);
  d_zeta = (d_zeta.sum(0, true)).sum(1, true);
  d_nu = (d_nu.sum(0, true)).sum(1, true);
  return {d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h};
}
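As a sanity check for the loop above, here is a pure-PyTorch sketch of the same unrolled forward pass. It is illustrative only: the function name is hypothetical, and it assumes the forward kernel computes z = gate(pre_comp + bias_z) and h' = tanh(pre_comp + bias_h_prime), matching the FastGRNN cell stated earlier.

import torch

def fastgrnn_unroll_reference(x, w, u, bias_z, bias_h_prime, zeta, nu,
                              initial_h, gate=torch.sigmoid):
    # Sketch of fastgrnn_unroll_cuda_forward in plain PyTorch (not the shipped
    # implementation). x: [timesteps, batch, input_size];
    # w: [input_size, hidden]; u: [hidden, hidden].
    zeta, nu = torch.sigmoid(zeta), torch.sigmoid(nu)   # host code does the same
    prev_h, hidden_states = initial_h, []
    for t in range(x.size(0)):
        pre_comp = x[t] @ w + prev_h @ u                # the addmm in the CUDA loop
        z = gate(pre_comp + bias_z)                     # gate activation
        h_prime = torch.tanh(pre_comp + bias_h_prime)   # candidate state
        prev_h = (zeta * (1.0 - z) + nu) * h_prime + z * prev_h
        hidden_states.append(prev_h)
    return torch.stack(hidden_states)                   # [timesteps, batch, hidden]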

pytorch/edgeml_pytorch/graph/rnn.py

Lines changed: 44 additions & 0 deletions
@@ -1063,6 +1063,34 @@ def getVars(self):
    def forward(self, input, hiddenState=None, cellState=None):
        return self.unrollRNN(input, hiddenState, cellState)

class FastGRNNCUDA(nn.Module):
    """Unrolled implementation of the FastGRNNCUDACell"""
    def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid",
                 zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
        super(FastGRNNCUDA, self).__init__()
        if utils.findCUDA() is None:
            raise Exception('FastGRNNCUDA is supported only on GPU devices.')
        NON_LINEARITY = {"sigmoid": 0, "relu": 1, "tanh": 2}
        self._input_size = input_size
        self._hidden_size = hidden_size
        self._zetaInit = zetaInit
        self._nuInit = nuInit
        self._name = name
        self._gate_non_linearity = NON_LINEARITY[gate_non_linearity]
        self.W = nn.Parameter(0.1 * torch.randn([input_size, hidden_size]))
        self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))

        self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
        self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
        self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
        self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))

    def forward(self, input, h_state, cell_state=None):
        # input: [timesteps, batch, features]
        return FastGRNNUnrollFunction.apply(input, self.W, self.U, self.bias_gate,
                                            self.bias_update, self.zeta, self.nu,
                                            h_state, self._gate_non_linearity)

    def getVars(self):
        return [self.W, self.U, self.bias_gate, self.bias_update, self.zeta, self.nu]

class SRNN2(nn.Module):

    def __init__(self, inputDim, outputDim, hiddenDim0, hiddenDim1, cellType,
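Because FastGRNNCUDA routes the whole sequence through one fused autograd Function, a quick numerical check of the fused backward is worthwhile. A sketch using torch.autograd.gradcheck (float64, small sizes; assumes the fastgrnn_cuda extension is built and a CUDA device is available):

import torch
from torch.autograd import gradcheck
from edgeml_pytorch.graph.rnn import FastGRNNCUDA

torch.manual_seed(0)
rnn = FastGRNNCUDA(input_size=8, hidden_size=12).double().cuda()  # gradcheck wants float64
x = torch.randn(5, 4, 8, dtype=torch.float64, device="cuda", requires_grad=True)
h0 = torch.zeros(4, 12, dtype=torch.float64, device="cuda")

# Compares backward_unroll's analytic gradients against finite differences.
print(gradcheck(lambda inp: rnn(inp, h0), (x,), eps=1e-6, atol=1e-4))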
@@ -1195,3 +1223,19 @@ def backward(ctx, grad_h):
        d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h = outputs
        return d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h, None

class FastGRNNUnrollFunction(Function):
    @staticmethod
    def forward(ctx, input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity):
        outputs = fastgrnn_cuda.forward_unroll(input, w, u, bias_gate, bias_update,
                                               zeta, nu, old_h, gate_non_linearity)
        hidden_states = outputs[0]
        variables = [input, hidden_states, zeta, nu, w, u] + outputs[1:] + [old_h]
        ctx.save_for_backward(*variables)
        ctx.gate_non_linearity = gate_non_linearity
        return hidden_states

    @staticmethod
    def backward(ctx, grad_h):
        outputs = fastgrnn_cuda.backward_unroll(
            grad_h.contiguous(), *ctx.saved_variables, ctx.gate_non_linearity)
        d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h = outputs
        # gate_non_linearity is a non-tensor input, so its gradient slot is None.
        return d_input, d_w, d_u, d_bias_gate, d_bias_update, d_zeta, d_nu, d_old_h, None
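End to end, the module is used like any other recurrent layer, with time-major input and an explicit initial state. A minimal usage sketch (assuming the fastgrnn_cuda extension is installed; sizes are arbitrary):

import torch
from edgeml_pytorch.graph.rnn import FastGRNNCUDA

timesteps, batch, input_size, hidden_size = 50, 32, 16, 64
rnn = FastGRNNCUDA(input_size, hidden_size, gate_non_linearity="sigmoid").cuda()

x = torch.randn(timesteps, batch, input_size, device="cuda")
h0 = torch.zeros(batch, hidden_size, device="cuda")

hidden_states = rnn(x, h0)          # [timesteps, batch, hidden_size]
loss = hidden_states[-1].sum()
loss.backward()                     # dispatches to fastgrnn_cuda.backward_unroll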
