
Commit 61ce245

fastgrnncuda: precompute sigmoid
1 parent a8df961 commit 61ce245

File tree: 3 files changed (+127 −154 lines)
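
What the change does (inferred from the diff below; the commit message itself is terse): both CUDA kernels previously evaluated sigmoid(zeta[0][0]) and sigmoid(nu[0][0]) inside every thread, and the backward kernel let every thread write the scalar gradients d_zeta[0][0] and d_nu[0][0] unsynchronized. After this commit the two scalar gate parameters are squashed once on the host with torch::sigmoid before each launch, and their gradients are written per element and reduced afterwards. For reference, the FastGRNN cell that both kernels implement is

$$
\begin{aligned}
z_t &= \sigma(W x_t + U h_{t-1} + b_z),\\
\tilde h_t &= \tanh(W x_t + U h_{t-1} + b_h),\\
h_t &= \bigl(\sigma(\zeta)\,(1 - z_t) + \sigma(\nu)\bigr) \odot \tilde h_t + z_t \odot h_{t-1}.
\end{aligned}
$$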

pytorch/edgeml_pytorch/cuda/fastgrnn_cuda.cpp (25 additions, 51 deletions)
@@ -2,35 +2,27 @@
 
 #include <vector>
 
-// CUDA forward declarations
-
 std::vector<torch::Tensor> fastgrnn_cuda_forward(
     torch::Tensor input,
     torch::Tensor w,
     torch::Tensor u,
-    torch::Tensor bias_z,
-    torch::Tensor bias_h_prime,
-    torch::Tensor old_h,
+    torch::Tensor bias_gate,
+    torch::Tensor bias_update,
     torch::Tensor zeta,
-    torch::Tensor nu);
+    torch::Tensor nu,
+    torch::Tensor old_h);
 
 std::vector<torch::Tensor> fastgrnn_cuda_backward(
     torch::Tensor grad_h,
     torch::Tensor input,
     torch::Tensor old_h,
-    torch::Tensor z_t,
-    torch::Tensor h_prime_t,
-    torch::Tensor pre_comp,
+    torch::Tensor zeta,
+    torch::Tensor nu,
     torch::Tensor w,
     torch::Tensor u,
-    torch::Tensor bias_z,
-    torch::Tensor bias_h_prime,
-    torch::Tensor zeta,
-    torch::Tensor nu);
+    torch::Tensor z,
+    torch::Tensor h_prime);
 
-// C++ interface
-
-// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
@@ -39,62 +31,44 @@ std::vector<torch::Tensor> fastgrnn_forward(
     torch::Tensor input,
     torch::Tensor w,
     torch::Tensor u,
-    torch::Tensor bias_z,
-    torch::Tensor bias_h_prime,
-    torch::Tensor old_h,
+    torch::Tensor bias_gate,
+    torch::Tensor bias_update,
     torch::Tensor zeta,
-    torch::Tensor nu) {
+    torch::Tensor nu,
+    torch::Tensor old_h) {
   CHECK_INPUT(input);
   CHECK_INPUT(w);
   CHECK_INPUT(u);
-  CHECK_INPUT(bias_z);
-  CHECK_INPUT(bias_h_prime);
-  CHECK_INPUT(old_h);
+  CHECK_INPUT(bias_gate);
+  CHECK_INPUT(bias_update);
   CHECK_INPUT(zeta);
   CHECK_INPUT(nu);
+  CHECK_INPUT(old_h);
 
-  return fastgrnn_cuda_forward(input, w, u, bias_z, bias_h_prime, old_h, zeta, nu);
+  return fastgrnn_cuda_forward(input, w, u, bias_gate, bias_update, zeta, nu, old_h);
 }
 
 std::vector<torch::Tensor> fastgrnn_backward(
     torch::Tensor grad_h,
     torch::Tensor input,
     torch::Tensor old_h,
-    torch::Tensor z_t,
-    torch::Tensor h_prime_t,
-    torch::Tensor pre_comp,
+    torch::Tensor zeta,
+    torch::Tensor nu,
     torch::Tensor w,
     torch::Tensor u,
-    torch::Tensor bias_z,
-    torch::Tensor bias_h_prime,
-    torch::Tensor zeta,
-    torch::Tensor nu) {
+    torch::Tensor z,
+    torch::Tensor h_prime) {
   CHECK_INPUT(grad_h);
   CHECK_INPUT(input);
   CHECK_INPUT(old_h);
-  CHECK_INPUT(z_t);
-  CHECK_INPUT(h_prime_t);
-  CHECK_INPUT(pre_comp);
-  CHECK_INPUT(w);
-  CHECK_INPUT(u);
-  CHECK_INPUT(bias_z);
-  CHECK_INPUT(bias_h_prime);
   CHECK_INPUT(zeta);
   CHECK_INPUT(nu);
+  CHECK_INPUT(z);
+  CHECK_INPUT(h_prime);
+  CHECK_INPUT(w);
+  CHECK_INPUT(u);
 
-  return fastgrnn_cuda_backward(
-      grad_h,
-      input,
-      old_h,
-      z_t,
-      h_prime_t,
-      pre_comp,
-      w,
-      u,
-      bias_z,
-      bias_h_prime,
-      zeta,
-      nu);
+  return fastgrnn_cuda_backward(grad_h, input, old_h, zeta, nu, w, u, z, h_prime);
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
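
To make the reordered interface concrete, here is a minimal sketch of a host-side call. The fastgrnn_forward declaration is copied from the wrapper above; the harness, tensor shapes, and values are assumptions for illustration and would have to link against the built extension:

#include <torch/torch.h>
#include <vector>

// Declaration matching the rebound interface above (old_h now comes last).
std::vector<torch::Tensor> fastgrnn_forward(
    torch::Tensor input, torch::Tensor w, torch::Tensor u,
    torch::Tensor bias_gate, torch::Tensor bias_update,
    torch::Tensor zeta, torch::Tensor nu, torch::Tensor old_h);

int main() {
  // Hypothetical sizes: batch 32, input 64, hidden 128. Every argument must
  // be a contiguous CUDA tensor, or the CHECK_INPUT macros above will abort.
  auto opts = torch::dtype(torch::kFloat32).device(torch::kCUDA);
  auto input       = torch::randn({32, 64}, opts);
  auto w           = torch::randn({128, 64}, opts);   // used as w^T inside
  auto u           = torch::randn({128, 128}, opts);  // used as u^T inside
  auto bias_gate   = torch::randn({1, 128}, opts);    // broadcast over batch
  auto bias_update = torch::randn({1, 128}, opts);
  auto zeta        = torch::randn({1, 1}, opts);      // raw; sigmoid applied inside
  auto nu          = torch::randn({1, 1}, opts);
  auto old_h       = torch::zeros({32, 128}, opts);

  auto outs = fastgrnn_forward(input, w, u, bias_gate, bias_update,
                               zeta, nu, old_h);
  // outs = {new_h, z, h_prime}; z and h_prime are cached for the backward pass.
  return 0;
}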

pytorch/edgeml_pytorch/cuda/fastgrnn_cuda_kernel.cu (96 additions, 97 deletions)
@@ -5,76 +5,77 @@
 
 #include <vector>
 
+__forceinline__ torch::Tensor d_sigmoid(torch::Tensor z) {
+  return (1 - z) * z;
+}
+
+__forceinline__ torch::Tensor d_tanh(torch::Tensor z) {
+  return 1 - z.pow(2);
+}
+
+
 namespace {
 template <typename scalar_t>
 __device__ __forceinline__ scalar_t sigmoid(scalar_t z) {
   return 1.0 / (1.0 + exp(-z));
 }
 
 template <typename scalar_t>
-__device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) {
-  const auto s = sigmoid(z);
-  return (1.0 - s) * s;
+__device__ __forceinline__ scalar_t d_sigmoid(scalar_t sig_z) {
+  return (1.0 - sig_z) * sig_z;
 }
 
 template <typename scalar_t>
-__device__ __forceinline__ scalar_t d_tanh(scalar_t z) {
-  const auto t = tanh(z);
-  return 1 - (t * t);
+__device__ __forceinline__ scalar_t d_tanh(scalar_t tan_z) {
+  return 1 - (tan_z * tan_z);
 }
 
 template <typename scalar_t>
 __global__ void fastgrnn_cuda_forward_kernel(
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime,
     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> pre_comp,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z_t,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime_t,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_z,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_h_prime,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> zeta,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu) {
-  //batch index
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_z,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_h_prime,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> zeta,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h) {
   const int n = blockIdx.y;
-  // column index
   const int c = blockIdx.x * blockDim.x + threadIdx.x;
-  if (c < pre_comp.size(1)){
-    z_t[n][c] = sigmoid(pre_comp[n][c] + bias_z[n][c]);
-    h_prime_t[n][c] = tanh(pre_comp[n][c] + bias_h_prime[n][c]);
-
-    new_h[n][c] = (sigmoid(zeta[0][0]) * (1 - z_t[n][c]) + sigmoid(nu[0][0])) * h_prime_t[n][c] + z_t[n][c] * old_h[n][c];
+  if (c < old_h.size(1)){
+    z[n][c] = sigmoid(pre_comp[n][c] + bias_z[0][c]);
+    h_prime[n][c] = tanh(pre_comp[n][c] + bias_h_prime[0][c]);
+    new_h[n][c] = (zeta[0][0] * (1.0 - z[n][c]) + nu[0][0]) * h_prime[n][c] + old_h[n][c] * z[n][c];
   }
 }
 
+
 template <typename scalar_t>
 __global__ void fastgrnn_cuda_backward_kernel(
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu,
     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_precomp,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_z,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_h_prime_t,
     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_h,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_z,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_h_prime,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta,
     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z_t,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime_t,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> pre_comp,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_z,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_h_prime,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime,
     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> zeta,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu) {
-  //batch index
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta_sigmoid,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu_sigmoid,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h) {
   const int n = blockIdx.y;
-  // column index
   const int c = blockIdx.x * blockDim.x + threadIdx.x;
-  if (c < d_precomp.size(1)){
-    auto temp_grad = grad_h[n][c] * h_prime_t[n][c];
-    d_zeta[0][0] = temp_grad * (1 - z_t[n][c]) * d_sigmoid(zeta[0][0]);
-    d_nu[0][0] = temp_grad * d_sigmoid(nu[0][0]);
-    d_bias_z[n][c] = grad_h[n][c] * (sigmoid(zeta[0][0]) * -1 * h_prime_t[n][c] + old_h[n][c]) * d_sigmoid(pre_comp[n][c] + bias_z[n][c]);
-    d_bias_h_prime_t[n][c] = grad_h[n][c] * (sigmoid(zeta[0][0]) * (1 - z_t[n][c]) + sigmoid(nu[0][0])) * d_tanh(pre_comp[n][c] + bias_h_prime[n][c]);
-    d_old_h[n][c] = grad_h[n][c] * z_t[n][c];
-    d_precomp[n][c] = d_bias_z[n][c] + d_bias_h_prime_t[n][c];
+  if (c < old_h.size(1)){
+    d_old_h[n][c] = z[n][c] * grad_h[n][c];
+    d_bias_h_prime[n][c] = (zeta[0][0] * (1.0 - z[n][c]) + nu[0][0]) * d_tanh(h_prime[n][c]) * grad_h[n][c];
+    d_bias_z[n][c] = (old_h[n][c] - zeta[0][0] * h_prime[n][c]) * d_sigmoid(z[n][c]) * grad_h[n][c];
+    d_precomp[n][c] = d_bias_z[n][c] + d_bias_h_prime[n][c];
+    d_zeta[n][c] = (1.0 - z[n][c]) * h_prime[n][c] * grad_h[n][c] * d_zeta_sigmoid[0][0];
+    d_nu[n][c] = h_prime[n][c] * grad_h[n][c] * d_nu_sigmoid[0][0];
  }
 }
 } // namespace
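
To check the rewritten backward kernel against the forward definition: write p = W x_t + U h_{t-1} for the pre-activation, s_zeta = σ(ζ) and s_nu = σ(ν) for the host-precomputed sigmoids, and g = ∂L/∂h_t for the incoming gradient. The per-element assignments above then compute

$$
\begin{aligned}
\frac{\partial L}{\partial b_z} &= g \odot \bigl(h_{t-1} - s_\zeta \tilde h_t\bigr) \odot z_t (1 - z_t), &
\frac{\partial L}{\partial b_h} &= g \odot \bigl(s_\zeta (1 - z_t) + s_\nu\bigr) \odot (1 - \tilde h_t^2),\\
\frac{\partial L}{\partial p} &= \frac{\partial L}{\partial b_z} + \frac{\partial L}{\partial b_h}, &
\frac{\partial L}{\partial h_{t-1}}\bigg|_{\text{direct}} &= g \odot z_t,\\
\frac{\partial L}{\partial \zeta} &= s_\zeta (1 - s_\zeta) \sum_{n,c} g \odot (1 - z_t) \odot \tilde h_t, &
\frac{\partial L}{\partial \nu} &= s_\nu (1 - s_\nu) \sum_{n,c} g \odot \tilde h_t.
\end{aligned}
$$

Note that d_sigmoid and d_tanh are now applied to the saved activations z and h_prime rather than recomputed from pre-activations, the batch/state reductions for the bias, zeta, and nu gradients happen on the host after the launch (see below), and the per-element d_zeta/d_nu writes replace the old kernel's unsynchronized stores to d_zeta[0][0] and d_nu[0][0].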
@@ -85,88 +86,86 @@ std::vector<torch::Tensor> fastgrnn_cuda_forward(
     torch::Tensor u,
     torch::Tensor bias_z,
     torch::Tensor bias_h_prime,
-    torch::Tensor old_h,
     torch::Tensor zeta,
-    torch::Tensor nu) {
-  auto w_comp = torch::mm(input, w);
-  auto u_comp = torch::mm(old_h, u);
-  auto pre_comp = torch::add(u_comp, w_comp);
-
+    torch::Tensor nu,
+    torch::Tensor old_h) {
+
+  auto pre_comp = torch::addmm(torch::mm(input, w.transpose(0, 1)), old_h, u.transpose(0, 1));
+  nu = torch::sigmoid(nu);
+  zeta = torch::sigmoid(zeta);
   const auto batch_size = old_h.size(0);
   const auto state_size = old_h.size(1);
-
   auto new_h = torch::zeros_like(old_h);
-  auto z_t = torch::zeros_like(old_h);
-  auto h_prime_t = torch::zeros_like(old_h);
-
+  auto z = torch::zeros_like(old_h);
+  auto h_prime = torch::zeros_like(old_h);
   const int threads = 1024;
   const dim3 blocks((state_size + threads - 1) / threads, batch_size);
-
   AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] {
     fastgrnn_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-        pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
         new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        z_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        h_prime_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
         bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
         bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
         zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+        old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
   }));
-
-  return {new_h, z_t, h_prime_t, pre_comp};
+  return {new_h, z, h_prime};
 }
 
 std::vector<torch::Tensor> fastgrnn_cuda_backward(
     torch::Tensor grad_h,
     torch::Tensor input,
     torch::Tensor old_h,
-    torch::Tensor z_t,
-    torch::Tensor h_prime_t,
-    torch::Tensor pre_comp,
+    torch::Tensor zeta,
+    torch::Tensor nu,
     torch::Tensor w,
     torch::Tensor u,
-    torch::Tensor bias_z,
-    torch::Tensor bias_h_prime,
-    torch::Tensor zeta,
-    torch::Tensor nu) {
-  auto d_precomp = torch::zeros_like(pre_comp);
-  auto d_old_h = torch::zeros_like(old_h);
-  auto d_zeta = torch::zeros_like(zeta);
-  auto d_nu = torch::zeros_like(nu);
-  auto d_bias_z = torch::zeros_like(bias_z);
-  auto d_bias_h_prime = torch::zeros_like(bias_h_prime);
-
-  const auto batch_size = old_h.size(0);
-  const auto state_size = old_h.size(1);
-
-  const int threads = 1024;
-  const dim3 blocks((state_size + threads - 1) / threads, batch_size);
-
-  AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] {
+    torch::Tensor z,
+    torch::Tensor h_prime) {
+  auto d_precomp = torch::zeros_like(old_h);
+  auto d_bias_z = torch::zeros_like(old_h);
+  auto d_bias_h_prime = torch::zeros_like(old_h);
+  auto d_nu = torch::zeros_like(old_h);
+  auto d_zeta = torch::zeros_like(old_h);
+  auto d_old_h = torch::zeros_like(old_h);
+  zeta = torch::sigmoid(zeta);
+  nu = torch::sigmoid(nu);
+  auto d_nu_sigmoid = d_sigmoid(nu);
+  auto d_zeta_sigmoid = d_sigmoid(zeta);
+  const auto batch_size = old_h.size(0);
+  const auto state_size = old_h.size(1);
+
+  const int threads = 1024;
+  const dim3 blocks((state_size + threads - 1) / threads, batch_size);
+  AT_DISPATCH_FLOATING_TYPES(old_h.type(), "fastgrnn_backward_cuda", ([&] {
     fastgrnn_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
-        d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
         d_precomp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
         d_bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
         d_bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
         grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        z_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        h_prime_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
         zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+        nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        d_zeta_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        d_nu_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
   }));
 
-  d_old_h = torch::add(d_old_h, torch::mm(torch::add(d_bias_h_prime, d_bias_z), u.transpose(0, 1)));
-  auto d_input = torch::mm(d_precomp, w.transpose(0, 1));
-  auto d_w = torch::mm(input.transpose(0, 1), d_precomp);
-  auto d_u = torch::mm(old_h.transpose(0, 1), d_precomp);
-
-  return {d_old_h, d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_nu, d_zeta};
+  d_old_h = torch::addmm(d_old_h, d_precomp, u);
+  auto d_input = torch::mm(d_precomp, w);
+  auto d_w = torch::mm(d_precomp.transpose(0, 1), input);
+  auto d_u = torch::mm(d_precomp.transpose(0, 1), old_h);
+  d_bias_z = d_bias_z.sum(0, true);
+  d_bias_h_prime = d_bias_h_prime.sum(0, true);
+  d_zeta = (d_zeta.sum(0, true)).sum(1, true);
+  d_nu = (d_nu.sum(0, true)).sum(1, true);
+
+  return {d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h};
 }
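
Finally, a minimal CPU sketch of one forward step, mirroring the math of the kernel path after this commit; handy as a reference when testing. The function name, shape comments, and usage are assumptions for illustration, not repo API:

#include <torch/torch.h>
#include <vector>

// CPU reference for one FastGRNN step. Sigmoids of zeta/nu are applied once,
// outside the elementwise update, exactly as the CUDA wrapper now does.
std::vector<torch::Tensor> fastgrnn_step_reference(
    torch::Tensor input,        // [batch, input_size]
    torch::Tensor w,            // [state_size, input_size]
    torch::Tensor u,            // [state_size, state_size]
    torch::Tensor bias_gate,    // [1, state_size]
    torch::Tensor bias_update,  // [1, state_size]
    torch::Tensor zeta,         // [1, 1], raw parameter
    torch::Tensor nu,           // [1, 1], raw parameter
    torch::Tensor old_h) {      // [batch, state_size]
  // Same fused GEMM as the CUDA wrapper: x w^T + h u^T.
  auto pre_comp = torch::addmm(torch::mm(input, w.transpose(0, 1)),
                               old_h, u.transpose(0, 1));
  auto z = torch::sigmoid(pre_comp + bias_gate);       // gate
  auto h_prime = torch::tanh(pre_comp + bias_update);  // candidate update
  auto s_zeta = torch::sigmoid(zeta);                  // precomputed scalar
  auto s_nu = torch::sigmoid(nu);                      // precomputed scalar
  auto new_h = (s_zeta * (1 - z) + s_nu) * h_prime + old_h * z;
  return {new_h, z, h_prime};
}

Comparing this against the CUDA path, and against autograd on the same expression, is a cheap way to validate both the forward outputs and the hand-written gradients above.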
