Skip to content

Commit 5d0af0f

Browse files
Merge pull request #131 from MJ10/fastgrnn-cuda
FastGRNN CUDA
2 parents 0176ebc + 23f6e0c commit 5d0af0f

File tree

5 files changed

+1035
-1
lines changed

5 files changed

+1035
-1
lines changed
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
#include <torch/extension.h>
2+
3+
#include <vector>
4+
5+
std::vector<torch::Tensor> fastgrnn_cuda_forward(
6+
torch::Tensor input,
7+
torch::Tensor w,
8+
torch::Tensor u,
9+
torch::Tensor bias_gate,
10+
torch::Tensor bias_update,
11+
torch::Tensor zeta,
12+
torch::Tensor nu,
13+
torch::Tensor old_h,
14+
int z_non_linearity,
15+
torch::Tensor w1,
16+
torch::Tensor w2,
17+
torch::Tensor u1,
18+
torch::Tensor u2);
19+
20+
std::vector<torch::Tensor> fastgrnn_cuda_backward(
21+
torch::Tensor grad_h,
22+
torch::Tensor input,
23+
torch::Tensor old_h,
24+
torch::Tensor zeta,
25+
torch::Tensor nu,
26+
torch::Tensor w,
27+
torch::Tensor u,
28+
int z_non_linearity,
29+
torch::Tensor z,
30+
torch::Tensor h_prime,
31+
torch::Tensor w1,
32+
torch::Tensor w2,
33+
torch::Tensor u1,
34+
torch::Tensor u2);
35+
36+
std::vector<torch::Tensor> fastgrnn_unroll_cuda_forward(
37+
torch::Tensor input,
38+
torch::Tensor w,
39+
torch::Tensor u,
40+
torch::Tensor bias_z,
41+
torch::Tensor bias_h_prime,
42+
torch::Tensor zeta,
43+
torch::Tensor nu,
44+
torch::Tensor initial_h,
45+
int z_non_linearity,
46+
torch::Tensor w1,
47+
torch::Tensor w2,
48+
torch::Tensor u1,
49+
torch::Tensor u2);
50+
51+
std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
52+
torch::Tensor grad_h,
53+
torch::Tensor input,
54+
torch::Tensor hidden_states,
55+
torch::Tensor zeta,
56+
torch::Tensor nu,
57+
torch::Tensor w,
58+
torch::Tensor u,
59+
torch::Tensor z,
60+
torch::Tensor h_prime,
61+
torch::Tensor initial_h,
62+
int z_non_linearity,
63+
torch::Tensor w1,
64+
torch::Tensor w2,
65+
torch::Tensor u1,
66+
torch::Tensor u2);
67+
68+
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
69+
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
70+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
71+
72+
std::vector<torch::Tensor> fastgrnn_forward(
73+
torch::Tensor input,
74+
torch::Tensor w,
75+
torch::Tensor u,
76+
torch::Tensor bias_gate,
77+
torch::Tensor bias_update,
78+
torch::Tensor zeta,
79+
torch::Tensor nu,
80+
torch::Tensor old_h,
81+
int z_non_linearity,
82+
torch::Tensor w1,
83+
torch::Tensor w2,
84+
torch::Tensor u1,
85+
torch::Tensor u2) {
86+
CHECK_INPUT(input);
87+
if(w1.size(0) == 0) {
88+
CHECK_INPUT(w);
89+
} else {
90+
CHECK_INPUT(w1);
91+
CHECK_INPUT(w2);
92+
}
93+
if (u1.size(0) == 0) {
94+
CHECK_INPUT(u);
95+
} else {
96+
CHECK_INPUT(u1);
97+
CHECK_INPUT(u2);
98+
}
99+
CHECK_INPUT(bias_gate);
100+
CHECK_INPUT(bias_update);
101+
CHECK_INPUT(zeta);
102+
CHECK_INPUT(nu);
103+
CHECK_INPUT(old_h);
104+
105+
return fastgrnn_cuda_forward(input, w, u, bias_gate, bias_update, zeta, nu, old_h, z_non_linearity, w1, w2, u1, u2);
106+
}
107+
108+
std::vector<torch::Tensor> fastgrnn_backward(
109+
torch::Tensor grad_h,
110+
torch::Tensor input,
111+
torch::Tensor old_h,
112+
torch::Tensor zeta,
113+
torch::Tensor nu,
114+
torch::Tensor w,
115+
torch::Tensor u,
116+
torch::Tensor z,
117+
torch::Tensor h_prime,
118+
torch::Tensor w1,
119+
torch::Tensor w2,
120+
torch::Tensor u1,
121+
torch::Tensor u2,
122+
int z_non_linearity) {
123+
CHECK_INPUT(grad_h);
124+
CHECK_INPUT(input);
125+
CHECK_INPUT(old_h);
126+
CHECK_INPUT(zeta);
127+
CHECK_INPUT(nu);
128+
CHECK_INPUT(z);
129+
CHECK_INPUT(h_prime);
130+
if(w1.size(0) == 0) {
131+
CHECK_INPUT(w);
132+
} else {
133+
CHECK_INPUT(w1);
134+
CHECK_INPUT(w2);
135+
}
136+
if (u1.size(0) == 0) {
137+
CHECK_INPUT(u);
138+
} else {
139+
CHECK_INPUT(u1);
140+
CHECK_INPUT(u2);
141+
}
142+
143+
return fastgrnn_cuda_backward(grad_h, input, old_h, zeta, nu, w, u, z_non_linearity, z, h_prime, w1, w2, u1, u2);
144+
}
145+
146+
std::vector<torch::Tensor> fastgrnn_unroll_forward(
147+
torch::Tensor input,
148+
torch::Tensor w,
149+
torch::Tensor u,
150+
torch::Tensor bias_z,
151+
torch::Tensor bias_h_prime,
152+
torch::Tensor zeta,
153+
torch::Tensor nu,
154+
torch::Tensor initial_h,
155+
int z_non_linearity,
156+
torch::Tensor w1,
157+
torch::Tensor w2,
158+
torch::Tensor u1,
159+
torch::Tensor u2) {
160+
CHECK_INPUT(input);
161+
if(w1.size(0) == 0) {
162+
CHECK_INPUT(w);
163+
} else {
164+
CHECK_INPUT(w1);
165+
CHECK_INPUT(w2);
166+
}
167+
if (u1.size(0) == 0) {
168+
CHECK_INPUT(u);
169+
} else {
170+
CHECK_INPUT(u1);
171+
CHECK_INPUT(u2);
172+
}
173+
CHECK_INPUT(bias_z);
174+
CHECK_INPUT(bias_h_prime);
175+
CHECK_INPUT(initial_h);
176+
CHECK_INPUT(zeta);
177+
CHECK_INPUT(nu);
178+
return fastgrnn_unroll_cuda_forward(input, w, u, bias_z, bias_h_prime, zeta, nu, initial_h, z_non_linearity, w1, w2, u1, u2);
179+
}
180+
181+
std::vector<torch::Tensor> fastgrnn_unroll_backward(
182+
torch::Tensor grad_h,
183+
torch::Tensor input,
184+
torch::Tensor hidden_states,
185+
torch::Tensor zeta,
186+
torch::Tensor nu,
187+
torch::Tensor w,
188+
torch::Tensor u,
189+
torch::Tensor z,
190+
torch::Tensor h_prime,
191+
torch::Tensor initial_h,
192+
torch::Tensor w1,
193+
torch::Tensor w2,
194+
torch::Tensor u1,
195+
torch::Tensor u2,
196+
int z_non_linearity) {
197+
CHECK_INPUT(grad_h);
198+
CHECK_INPUT(input);
199+
CHECK_INPUT(hidden_states);
200+
CHECK_INPUT(z);
201+
CHECK_INPUT(h_prime);
202+
if(w1.size(0) == 0) {
203+
CHECK_INPUT(w);
204+
} else {
205+
CHECK_INPUT(w1);
206+
CHECK_INPUT(w2);
207+
}
208+
if (u1.size(0) == 0) {
209+
CHECK_INPUT(u);
210+
} else {
211+
CHECK_INPUT(u1);
212+
CHECK_INPUT(u2);
213+
}
214+
CHECK_INPUT(zeta);
215+
CHECK_INPUT(nu);
216+
CHECK_INPUT(initial_h);
217+
218+
return fastgrnn_unroll_cuda_backward(
219+
grad_h,
220+
input,
221+
hidden_states,
222+
zeta,
223+
nu,
224+
w,
225+
u,
226+
z,
227+
h_prime,
228+
initial_h,
229+
z_non_linearity,
230+
w1, w2, u1, u2);
231+
}
232+
233+
234+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
235+
m.def("forward", &fastgrnn_forward, "FastGRNN forward (CUDA)");
236+
m.def("backward", &fastgrnn_backward, "FastGRNN backward (CUDA)");
237+
m.def("forward_unroll", &fastgrnn_unroll_forward, "FastGRNN Unrolled forward (CUDA)");
238+
m.def("backward_unroll", &fastgrnn_unroll_backward, "FastGRNN Unrolled backward (CUDA)");
239+
}

0 commit comments

Comments
 (0)