
 #include <vector>

+__forceinline__ torch::Tensor d_sigmoid(torch::Tensor z) {
+  return (1 - z) * z;
+}
+
+__forceinline__ torch::Tensor d_tanh(torch::Tensor z) {
+  return 1 - z.pow(2);
+}
+
+
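A note on the two host-side helpers added above: unlike the old device functions, they take the already-activated value, so the derivative can be formed without re-running the nonlinearity. A minimal usage sketch (illustrative only, not part of this commit; assumes <torch/extension.h> is included):

    // Illustrative: pass the activation output, not the pre-activation.
    torch::Tensor x  = torch::randn({4, 8});
    torch::Tensor s  = torch::sigmoid(x);
    torch::Tensor ds = d_sigmoid(s);            // s * (1 - s), i.e. dsigma/dx evaluated at x
    torch::Tensor dt = d_tanh(torch::tanh(x));  // 1 - tanh(x)^2
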
 namespace {
 template <typename scalar_t>
 __device__ __forceinline__ scalar_t sigmoid(scalar_t z) {
   return 1.0 / (1.0 + exp(-z));
 }

 template <typename scalar_t>
-__device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) {
-  const auto s = sigmoid(z);
-  return (1.0 - s) * s;
+__device__ __forceinline__ scalar_t d_sigmoid(scalar_t sig_z) {
+  return (1.0 - sig_z) * sig_z;
 }

 template <typename scalar_t>
-__device__ __forceinline__ scalar_t d_tanh(scalar_t z) {
-  const auto t = tanh(z);
-  return 1 - (t * t);
+__device__ __forceinline__ scalar_t d_tanh(scalar_t tan_z) {
+  return 1 - (tan_z * tan_z);
 }

 template <typename scalar_t>
 __global__ void fastgrnn_cuda_forward_kernel(
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime,
     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> pre_comp,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z_t,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime_t,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_z,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_h_prime,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> zeta,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu) {
-  // batch index
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_z,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_h_prime,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> zeta,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h) {
   const int n = blockIdx.y;
-  // column index
   const int c = blockIdx.x * blockDim.x + threadIdx.x;
-  if (c < pre_comp.size(1)){
-    z_t[n][c] = sigmoid(pre_comp[n][c] + bias_z[n][c]);
-    h_prime_t[n][c] = tanh(pre_comp[n][c] + bias_h_prime[n][c]);
-
-    new_h[n][c] = (sigmoid(zeta[0][0]) * (1 - z_t[n][c]) + sigmoid(nu[0][0])) * h_prime_t[n][c] + z_t[n][c] * old_h[n][c];
+  if (c < old_h.size(1)){
+    z[n][c] = sigmoid(pre_comp[n][c] + bias_z[0][c]);
+    h_prime[n][c] = tanh(pre_comp[n][c] + bias_h_prime[0][c]);
+    new_h[n][c] = (zeta[0][0] * (1.0 - z[n][c]) + nu[0][0]) * h_prime[n][c] + old_h[n][c] * z[n][c];
   }
 }

+
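For reference, the rewritten forward kernel computes the FastGRNN cell update below, where pre_comp = x_t W^T + h_{t-1} U^T is formed on the host and zeta, nu arrive already passed through a sigmoid before launch:

    \begin{aligned}
    z_t         &= \sigma(\mathrm{pre\_comp} + b_z) \\
    \tilde{h}_t &= \tanh(\mathrm{pre\_comp} + b_h) \\
    h_t         &= \big(\zeta\,(1 - z_t) + \nu\big) \odot \tilde{h}_t + z_t \odot h_{t-1}
    \end{aligned}

Each thread handles one (batch, hidden) element; indexing the biases as bias[0][c] assumes they are stored as [1, hidden] row vectors.
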
 template <typename scalar_t>
 __global__ void fastgrnn_cuda_backward_kernel(
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu,
     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_precomp,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_z,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_h_prime_t,
     torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_h,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_z,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_h_prime,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta,
     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z_t,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime_t,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> pre_comp,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_z,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_h_prime,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime,
     const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> zeta,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu) {
-  // batch index
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta_sigmoid,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu_sigmoid,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h) {
   const int n = blockIdx.y;
-  // column index
   const int c = blockIdx.x * blockDim.x + threadIdx.x;
-  if (c < d_precomp.size(1)){
-    auto temp_grad = grad_h[n][c] * h_prime_t[n][c];
-    d_zeta[0][0] = temp_grad * (1 - z_t[n][c]) * d_sigmoid(zeta[0][0]);
-    d_nu[0][0] = temp_grad * d_sigmoid(nu[0][0]);
-    d_bias_z[n][c] = grad_h[n][c] * (sigmoid(zeta[0][0]) * -1 * h_prime_t[n][c] + old_h[n][c]) * d_sigmoid(pre_comp[n][c] + bias_z[n][c]);;
-    d_bias_h_prime_t[n][c] = grad_h[n][c] * (sigmoid(zeta[0][0]) * (1 - z_t[n][c]) + sigmoid(nu[0][0])) * d_tanh(pre_comp[n][c] + bias_h_prime[n][c]);
-    d_old_h[n][c] = grad_h[n][c] * z_t[n][c];
-    d_precomp[n][c] = d_bias_z[n][c] + d_bias_h_prime_t[n][c];
+  if (c < old_h.size(1)){
+    d_old_h[n][c] = z[n][c] * grad_h[n][c];
+    d_bias_h_prime[n][c] = (zeta[0][0] * (1.0 - z[n][c]) + nu[0][0]) * d_tanh(h_prime[n][c]) * grad_h[n][c];
+    d_bias_z[n][c] = (old_h[n][c] - zeta[0][0] * h_prime[n][c]) * d_sigmoid(z[n][c]) * grad_h[n][c];
+    d_precomp[n][c] = d_bias_z[n][c] + d_bias_h_prime[n][c];
+    d_zeta[n][c] = (1.0 - z[n][c]) * h_prime[n][c] * grad_h[n][c] * d_zeta_sigmoid[0][0];
+    d_nu[n][c] = h_prime[n][c] * grad_h[n][c] * d_nu_sigmoid[0][0];
   }
 }
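The backward kernel is the chain rule applied to the update above. With zeta and nu again pre-sigmoided on the host, and d_zeta_sigmoid = σ'(ζ_raw), d_nu_sigmoid = σ'(ν_raw) precomputed there, the per-element gradients it writes are:

    \begin{aligned}
    \partial h_{t-1} &= z_t \odot \partial h_t \\
    \partial b_h &= \big(\zeta(1 - z_t) + \nu\big)\,(1 - \tilde{h}_t^2) \odot \partial h_t \\
    \partial b_z &= \big(h_{t-1} - \zeta \tilde{h}_t\big)\, z_t (1 - z_t) \odot \partial h_t \\
    \partial\,\mathrm{pre\_comp} &= \partial b_z + \partial b_h \\
    \partial \zeta_{\mathrm{raw}} &= (1 - z_t)\,\tilde{h}_t \odot \partial h_t \cdot \sigma'(\zeta_{\mathrm{raw}}), \qquad
    \partial \nu_{\mathrm{raw}} = \tilde{h}_t \odot \partial h_t \cdot \sigma'(\nu_{\mathrm{raw}})
    \end{aligned}

The reductions over the batch (for the biases), over batch and hidden (for the scalar zeta and nu), and the second path of ∂h_{t-1} through U are all left to the host code further down.
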
 } // namespace
@@ -85,88 +86,86 @@ std::vector<torch::Tensor> fastgrnn_cuda_forward(
     torch::Tensor u,
     torch::Tensor bias_z,
     torch::Tensor bias_h_prime,
-    torch::Tensor old_h,
     torch::Tensor zeta,
-    torch::Tensor nu) {
-  auto w_comp = torch::mm(input, w);
-  auto u_comp = torch::mm(old_h, u);
-  auto pre_comp = torch::add(u_comp, w_comp);
-
+    torch::Tensor nu,
+    torch::Tensor old_h) {
+
+  auto pre_comp = torch::addmm(torch::mm(input, w.transpose(0, 1)), old_h, u.transpose(0, 1));
+  nu = torch::sigmoid(nu);
+  zeta = torch::sigmoid(zeta);
   const auto batch_size = old_h.size(0);
   const auto state_size = old_h.size(1);
-
   auto new_h = torch::zeros_like(old_h);
-  auto z_t = torch::zeros_like(old_h);
-  auto h_prime_t = torch::zeros_like(old_h);
-
+  auto z = torch::zeros_like(old_h);
+  auto h_prime = torch::zeros_like(old_h);
   const int threads = 1024;
   const dim3 blocks((state_size + threads - 1) / threads, batch_size);
-
   AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] {
     fastgrnn_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-        pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
         new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        z_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        h_prime_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
         bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
         bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
         zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+        old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
   }));
-
-  return {new_h, z_t, h_prime_t, pre_comp};
+  return {new_h, z, h_prime};
 }
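The launch configuration in fastgrnn_cuda_forward (mirrored in the backward launcher) assigns one thread per (batch, hidden) element: 1024 threads per block along the hidden dimension and one block row per batch element. A quick worked example of the arithmetic, with hypothetical sizes:

    // Illustrative only: the sizes here are hypothetical, not from the commit.
    const int threads    = 1024;
    const int state_size = 1536;   // hidden units
    const int batch_size = 64;
    const dim3 blocks((state_size + threads - 1) / threads, batch_size);
    // -> blocks = dim3(2, 64); threads with c >= state_size fall through the
    //    `c < old_h.size(1)` bounds check inside the kernel.
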

 std::vector<torch::Tensor> fastgrnn_cuda_backward(
     torch::Tensor grad_h,
     torch::Tensor input,
     torch::Tensor old_h,
-    torch::Tensor z_t,
-    torch::Tensor h_prime_t,
-    torch::Tensor pre_comp,
+    torch::Tensor zeta,
+    torch::Tensor nu,
     torch::Tensor w,
     torch::Tensor u,
-    torch::Tensor bias_z,
-    torch::Tensor bias_h_prime,
-    torch::Tensor zeta,
-    torch::Tensor nu) {
-  auto d_precomp = torch::zeros_like(pre_comp);
-  auto d_old_h = torch::zeros_like(old_h);
-  auto d_zeta = torch::zeros_like(zeta);
-  auto d_nu = torch::zeros_like(nu);
-  auto d_bias_z = torch::zeros_like(bias_z);
-  auto d_bias_h_prime = torch::zeros_like(bias_h_prime);
-
-  const auto batch_size = old_h.size(0);
-  const auto state_size = old_h.size(1);
-
-  const int threads = 1024;
-  const dim3 blocks((state_size + threads - 1) / threads, batch_size);
-
-  AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] {
+    torch::Tensor z,
+    torch::Tensor h_prime) {
+  auto d_precomp = torch::zeros_like(old_h);
+  auto d_bias_z = torch::zeros_like(old_h);
+  auto d_bias_h_prime = torch::zeros_like(old_h);
+  auto d_nu = torch::zeros_like(old_h);
+  auto d_zeta = torch::zeros_like(old_h);
+  auto d_old_h = torch::zeros_like(old_h);
+  zeta = torch::sigmoid(zeta);
+  nu = torch::sigmoid(nu);
+  auto d_nu_sigmoid = d_sigmoid(nu);
+  auto d_zeta_sigmoid = d_sigmoid(zeta);
+  const auto batch_size = old_h.size(0);
+  const auto state_size = old_h.size(1);
+
+  const int threads = 1024;
+  const dim3 blocks((state_size + threads - 1) / threads, batch_size);
+  AT_DISPATCH_FLOATING_TYPES(old_h.type(), "fastgrnn_backward_cuda", ([&] {
     fastgrnn_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
-        d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
         d_precomp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
         d_bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
         d_bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
         grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        z_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        h_prime_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
         zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+        nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        d_zeta_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        d_nu_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
   }));

-  d_old_h = torch::add(d_old_h, torch::mm(torch::add(d_bias_h_prime, d_bias_z), u.transpose(0, 1)));
-  auto d_input = torch::mm(d_precomp, w.transpose(0, 1));
-  auto d_w = torch::mm(input.transpose(0, 1), d_precomp);
-  auto d_u = torch::mm(old_h.transpose(0, 1), d_precomp);
-
-  return {d_old_h, d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_nu, d_zeta};
+  d_old_h = torch::addmm(d_old_h, d_precomp, u);
+  auto d_input = torch::mm(d_precomp, w);
+  auto d_w = torch::mm(d_precomp.transpose(0, 1), input);
+  auto d_u = torch::mm(d_precomp.transpose(0, 1), old_h);
+  d_bias_z = d_bias_z.sum(0, true);
+  d_bias_h_prime = d_bias_h_prime.sum(0, true);
+  d_zeta = (d_zeta.sum(0, true)).sum(1, true);
+  d_nu = (d_nu.sum(0, true)).sum(1, true);
+
+  return {d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h};
 }
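Two things worth noting about the new epilogue. First, the weights are now laid out as [hidden, input] and [hidden, hidden] (the transposes moved from the old backward matmuls into the forward pre_comp computation), so d_w and d_u are built from d_precomp transposed. Second, the return order now mirrors the forward argument order (input, w, u, bias_z, bias_h_prime, zeta, nu, old_h), presumably so the autograd binding can map gradients positionally. The final reductions restore the parameter shapes; a small shape sketch (illustrative, hypothetical sizes):

    // d_precomp, d_bias_*, d_zeta, d_nu all start as [batch, hidden] buffers.
    auto d_buf    = torch::zeros({64, 1536});
    auto d_bias   = d_buf.sum(0, true);              // per-unit bias grad  -> [1, 1536]
    auto d_scalar = d_buf.sum(0, true).sum(1, true); // scalar gate grad    -> [1, 1]
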