Commit 79bd969

add FastGRNNCUDACell

1 parent ae2668e

4 files changed: +369 -0 lines
Lines changed: 103 additions & 0 deletions (new file)

@@ -0,0 +1,103 @@
#include <torch/extension.h>

#include <vector>

// CUDA forward declarations

std::vector<torch::Tensor> fastgrnn_cuda_forward(
    torch::Tensor input,
    torch::Tensor w,
    torch::Tensor u,
    torch::Tensor bias_z,
    torch::Tensor bias_h_prime,
    torch::Tensor old_h,
    torch::Tensor zeta,
    torch::Tensor nu);

std::vector<torch::Tensor> fastgrnn_cuda_backward(
    torch::Tensor grad_h,
    torch::Tensor input,
    torch::Tensor old_h,
    torch::Tensor z_t,
    torch::Tensor h_prime_t,
    torch::Tensor pre_comp,
    torch::Tensor w,
    torch::Tensor u,
    torch::Tensor bias_z,
    torch::Tensor bias_h_prime,
    torch::Tensor zeta,
    torch::Tensor nu);

// C++ interface

// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> fastgrnn_forward(
    torch::Tensor input,
    torch::Tensor w,
    torch::Tensor u,
    torch::Tensor bias_z,
    torch::Tensor bias_h_prime,
    torch::Tensor old_h,
    torch::Tensor zeta,
    torch::Tensor nu) {
  CHECK_INPUT(input);
  CHECK_INPUT(w);
  CHECK_INPUT(u);
  CHECK_INPUT(bias_z);
  CHECK_INPUT(bias_h_prime);
  CHECK_INPUT(old_h);
  CHECK_INPUT(zeta);
  CHECK_INPUT(nu);

  return fastgrnn_cuda_forward(input, w, u, bias_z, bias_h_prime, old_h, zeta, nu);
}

std::vector<torch::Tensor> fastgrnn_backward(
    torch::Tensor grad_h,
    torch::Tensor input,
    torch::Tensor old_h,
    torch::Tensor z_t,
    torch::Tensor h_prime_t,
    torch::Tensor pre_comp,
    torch::Tensor w,
    torch::Tensor u,
    torch::Tensor bias_z,
    torch::Tensor bias_h_prime,
    torch::Tensor zeta,
    torch::Tensor nu) {
  CHECK_INPUT(grad_h);
  CHECK_INPUT(input);
  CHECK_INPUT(old_h);
  CHECK_INPUT(z_t);
  CHECK_INPUT(h_prime_t);
  CHECK_INPUT(pre_comp);
  CHECK_INPUT(w);
  CHECK_INPUT(u);
  CHECK_INPUT(bias_z);
  CHECK_INPUT(bias_h_prime);
  CHECK_INPUT(zeta);
  CHECK_INPUT(nu);

  return fastgrnn_cuda_backward(
      grad_h,
      input,
      old_h,
      z_t,
      h_prime_t,
      pre_comp,
      w,
      u,
      bias_z,
      bias_h_prime,
      zeta,
      nu);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &fastgrnn_forward, "FastGRNN forward (CUDA)");
  m.def("backward", &fastgrnn_backward, "FastGRNN backward (CUDA)");
}
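
The file above is only the host-side interface: it checks that every tensor is a contiguous CUDA tensor and forwards to the kernels declared at the top, which live in the .cu file shown next. For context, here is a minimal sketch of how such an extension is typically JIT-compiled with torch.utils.cpp_extension; the source file names are assumptions for illustration, since this capture does not show the new files' paths.

# Hypothetical JIT build of the extension bound above. The file names
# "fastgrnn_cuda.cpp" and "fastgrnn_cuda_kernel.cu" are assumptions.
from torch.utils.cpp_extension import load

fastgrnn_cuda = load(
    name="fastgrnn_cuda",
    sources=["fastgrnn_cuda.cpp", "fastgrnn_cuda_kernel.cu"],
    verbose=True)

# load() compiles and imports the module, so fastgrnn_cuda.forward and
# fastgrnn_cuda.backward are the two functions registered in the
# PYBIND11_MODULE block.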
Lines changed: 172 additions & 0 deletions (new file)

@@ -0,0 +1,172 @@
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

namespace {
template <typename scalar_t>
__device__ __forceinline__ scalar_t sigmoid(scalar_t z) {
  return 1.0 / (1.0 + exp(-z));
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) {
  const auto s = sigmoid(z);
  return (1.0 - s) * s;
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t d_tanh(scalar_t z) {
  const auto t = tanh(z);
  return 1 - (t * t);
}

template <typename scalar_t>
__global__ void fastgrnn_cuda_forward_kernel(
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> pre_comp,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z_t,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime_t,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_z,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_h_prime,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> zeta,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu) {
  // batch index
  const int n = blockIdx.y;
  // column index
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c < pre_comp.size(1)) {
    // bias_z and bias_h_prime have shape [1, hidden_size], so row 0 is
    // broadcast over the batch; zeta and nu are 1x1 scalars.
    z_t[n][c] = sigmoid(pre_comp[n][c] + bias_z[0][c]);
    h_prime_t[n][c] = tanh(pre_comp[n][c] + bias_h_prime[0][c]);

    new_h[n][c] = (sigmoid(zeta[0][0]) * (1 - z_t[n][c]) + sigmoid(nu[0][0])) * h_prime_t[n][c] + z_t[n][c] * old_h[n][c];
  }
}

template <typename scalar_t>
__global__ void fastgrnn_cuda_backward_kernel(
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_precomp,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_z,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_h_prime_t,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_h,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z_t,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime_t,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> pre_comp,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_z,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_h_prime,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> zeta,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu) {
  // batch index
  const int n = blockIdx.y;
  // column index
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c < d_precomp.size(1)) {
    const auto temp_grad = grad_h[n][c] * h_prime_t[n][c];
    // d_zeta, d_nu, d_bias_z and d_bias_h_prime_t are per-element
    // scratch buffers shaped like pre_comp; every thread writes its own
    // contribution, and the host reduces them to the parameter shapes
    // after the kernel (zeta/nu are shared by all elements, the biases
    // by all rows of the batch).
    d_zeta[n][c] = temp_grad * (1 - z_t[n][c]) * d_sigmoid(zeta[0][0]);
    d_nu[n][c] = temp_grad * d_sigmoid(nu[0][0]);
    d_bias_z[n][c] = grad_h[n][c] * (sigmoid(zeta[0][0]) * -1 * h_prime_t[n][c] + old_h[n][c]) * d_sigmoid(pre_comp[n][c] + bias_z[0][c]);
    d_bias_h_prime_t[n][c] = grad_h[n][c] * (sigmoid(zeta[0][0]) * (1 - z_t[n][c]) + sigmoid(nu[0][0])) * d_tanh(pre_comp[n][c] + bias_h_prime[0][c]);
    d_old_h[n][c] = grad_h[n][c] * z_t[n][c];
    d_precomp[n][c] = d_bias_z[n][c] + d_bias_h_prime_t[n][c];
  }
}
} // namespace

std::vector<torch::Tensor> fastgrnn_cuda_forward(
    torch::Tensor input,
    torch::Tensor w,
    torch::Tensor u,
    torch::Tensor bias_z,
    torch::Tensor bias_h_prime,
    torch::Tensor old_h,
    torch::Tensor zeta,
    torch::Tensor nu) {
  auto w_comp = torch::mm(input, w);
  auto u_comp = torch::mm(old_h, u);
  auto pre_comp = torch::add(u_comp, w_comp);

  const auto batch_size = old_h.size(0);
  const auto state_size = old_h.size(1);

  auto new_h = torch::zeros_like(old_h);
  auto z_t = torch::zeros_like(old_h);
  auto h_prime_t = torch::zeros_like(old_h);

  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);

  AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] {
    fastgrnn_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
        pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        z_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        h_prime_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
  }));

  return {new_h, z_t, h_prime_t, pre_comp};
}

std::vector<torch::Tensor> fastgrnn_cuda_backward(
    torch::Tensor grad_h,
    torch::Tensor input,
    torch::Tensor old_h,
    torch::Tensor z_t,
    torch::Tensor h_prime_t,
    torch::Tensor pre_comp,
    torch::Tensor w,
    torch::Tensor u,
    torch::Tensor bias_z,
    torch::Tensor bias_h_prime,
    torch::Tensor zeta,
    torch::Tensor nu) {
  auto d_precomp = torch::zeros_like(pre_comp);
  auto d_old_h = torch::zeros_like(old_h);
  // Per-element scratch buffers (see the kernel comment); reduced to the
  // parameter shapes before returning.
  auto d_zeta = torch::zeros_like(pre_comp);
  auto d_nu = torch::zeros_like(pre_comp);
  auto d_bias_z = torch::zeros_like(pre_comp);
  auto d_bias_h_prime = torch::zeros_like(pre_comp);

  const auto batch_size = old_h.size(0);
  const auto state_size = old_h.size(1);

  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);

  AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_backward_cuda", ([&] {
    fastgrnn_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
        d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        d_precomp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        d_bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        d_bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        z_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        h_prime_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
        nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
  }));

  // d_precomp == d_bias_z + d_bias_h_prime element-wise (set in the kernel).
  d_old_h = torch::add(d_old_h, torch::mm(d_precomp, u.transpose(0, 1)));
  auto d_input = torch::mm(d_precomp, w.transpose(0, 1));
  auto d_w = torch::mm(input.transpose(0, 1), d_precomp);
  auto d_u = torch::mm(old_h.transpose(0, 1), d_precomp);

  return {d_old_h, d_input, d_w, d_u,
          d_bias_z.sum(0, true), d_bias_h_prime.sum(0, true),
          d_nu.sum().view({1, 1}), d_zeta.sum().view({1, 1})};
}
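
Per element, the forward kernel computes z_t = sigmoid(pre_comp + bias_z), h_t^ = tanh(pre_comp + bias_h_prime), and new_h = (sigmoid(zeta)*(1 - z_t) + sigmoid(nu))*h_t^ + z_t*old_h, where pre_comp = x@W + h@U is formed with two GEMMs on the host side. A plain-PyTorch restatement of the same math makes a convenient cross-check for the kernel output; the sketch below is illustrative and not part of this commit.

# Plain-PyTorch reference for the CUDA forward pass above; a sketch for
# cross-checking the kernel output (e.g. with torch.allclose), not part
# of this commit. Assumes 2-D inputs, [1, hidden] biases, 1x1 zeta/nu.
import torch

def fastgrnn_forward_reference(input, w, u, bias_z, bias_h_prime, old_h, zeta, nu):
    pre_comp = input @ w + old_h @ u                 # x@W + h@U
    z_t = torch.sigmoid(pre_comp + bias_z)           # gate
    h_prime_t = torch.tanh(pre_comp + bias_h_prime)  # candidate state
    new_h = (torch.sigmoid(zeta) * (1 - z_t)
             + torch.sigmoid(nu)) * h_prime_t + z_t * old_h
    return new_h, z_t, h_prime_t, pre_comp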

pytorch/edgeml_pytorch/graph/rnn.py

Lines changed: 78 additions & 0 deletions

@@ -1,13 +1,17 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT license.
 
+import os
 import torch
 import torch.nn as nn
 from torch.autograd import Function
 import numpy as np
 
 import edgeml_pytorch.utils as utils
 
+if "CUDA_HOME" in os.environ:
+    import fastgrnn_cuda
+
 def onnx_exportable_rnn(input, fargs, cell, output):
     class RNNSymbolic(Function):
         @staticmethod
@staticmethod
@@ -296,6 +300,63 @@ def getVars(self):
296300
Vars.extend([self.zeta, self.nu])
297301
return Vars
298302

303+
class FastGRNNCUDACell(RNNCell):
304+
'''
305+
A CUDA implementation of FastGRNN Cell with Full Rank Support
306+
hidden_size = # hidden units
307+
308+
zetaInit = init for zeta, the scale param
309+
nuInit = init for nu, the translation param
310+
311+
FastGRNN architecture and compression techniques are found in
312+
FastGRNN(LINK) paper
313+
314+
Basic architecture is like:
315+
316+
z_t = sigmoid(Wx_t + Uh_{t-1} + B_g)
317+
h_t^ = tanh(Wx_t + Uh_{t-1} + B_h)
318+
h_t = z_t*h_{t-1} + (sigmoid(zeta)(1-z_t) + sigmoid(nu))*h_t^
319+
320+
'''
321+
def __init__(self, input_size, hidden_size, zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
322+
super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, "sigmoid", "tanh", 1, 1, 2)
323+
if not "CUDA_HOME" in os.environ:
324+
raise Exception('FastGRNNCUDACell is supported only on GPU devices.')
325+
self._input_size = input_size
326+
self._hidden_size = hidden_size
327+
self._gate_non_linearity = gate_non_linearity
328+
self._update_non_linearity = update_non_linearity
329+
self._zetaInit = zetaInit
330+
self._nuInit = nuInit
331+
self._name = name
332+
333+
self.W = nn.Parameter(0.1 * torch.randn([input_size, hidden_size]))
334+
self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
335+
336+
self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
337+
self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
338+
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
339+
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))
340+
341+
def reset_parameters(self):
342+
stdv = 1.0 / math.sqrt(self.state_size)
343+
for weight in self.parameters():
344+
weight.data.uniform_(-stdv, +stdv)
345+
346+
@property
347+
def name(self):
348+
return self._name
349+
350+
@property
351+
def cellType(self):
352+
return "FastGRNNCUDACell"
353+
354+
def forward(self, input, state):
355+
# Calls the custom autograd function while invokes the CUDA implementation
356+
return FastGRNNFunction.apply(input, self.W, self.U, self.bias_gate, self.bias_update, state, self.zeta, self.nu)
357+
358+
def getVars(self):
359+
return [self.W, self.U, self.bias_gate, self.bias_update, self.zeta, self.nu]
299360

300361
class FastRNNCell(RNNCell):
301362
'''
@@ -1117,3 +1178,20 @@ def forward(self, x, brickSize):
         hidd1 = torch.squeeze(hidd1[-1])
         out = torch.matmul(hidd1, self.W) + self.B
         return out
+
+class FastGRNNFunction(Function):
+    @staticmethod
+    def forward(ctx, input, w, u, bias_z, bias_h_prime, old_h, zeta, nu):
+        outputs = fastgrnn_cuda.forward(input, w, u, bias_z, bias_h_prime, old_h, zeta, nu)
+        new_h = outputs[0]
+        variables = [input, old_h] + outputs[1:] + [w, u, bias_z, bias_h_prime, zeta, nu]
+        ctx.save_for_backward(*variables)
+        return new_h
+
+    @staticmethod
+    def backward(ctx, grad_h):
+        outputs = fastgrnn_cuda.backward(
+            grad_h.contiguous(), *ctx.saved_variables)
+        d_old_h, d_input, d_w, d_u, d_bias_z, d_bias_h_prime_t, d_nu, d_zeta = outputs
+        return d_input, d_w, d_u, d_bias_z, d_bias_h_prime_t, d_old_h, d_zeta, d_nu
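
For context, a short usage sketch for the new cell follows; it is illustrative, not part of this commit, and assumes the fastgrnn_cuda extension has been built and CUDA_HOME is set.

# Hypothetical usage of FastGRNNCUDACell; assumes the fastgrnn_cuda
# extension is importable and a CUDA device is available.
import torch
from edgeml_pytorch.graph.rnn import FastGRNNCUDACell

cell = FastGRNNCUDACell(input_size=32, hidden_size=64).cuda()
x = torch.randn(16, 32, device="cuda")   # one time step, batch of 16
h = torch.zeros(16, 64, device="cuda")   # initial hidden state
h = cell(x, h)                           # new hidden state, shape (16, 64)

Because the backward pass is hand-written, running torch.autograd.gradcheck on FastGRNNFunction.apply with small double-precision CUDA tensors is a sensible way to validate it against PyTorch's numerical gradients.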
