|
| 1 | +#include <torch/extension.h> |
| 2 | + |
| 3 | +#include <cuda.h> |
| 4 | +#include <cuda_runtime.h> |
| 5 | + |
| 6 | +#include <vector> |
| 7 | + |
| 8 | +namespace { |
| 9 | +template <typename scalar_t> |
| 10 | +__device__ __forceinline__ scalar_t sigmoid(scalar_t z) { |
| 11 | + return 1.0 / (1.0 + exp(-z)); |
| 12 | +} |
| 13 | + |
| 14 | +template <typename scalar_t> |
| 15 | +__device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) { |
| 16 | + const auto s = sigmoid(z); |
| 17 | + return (1.0 - s) * s; |
| 18 | +} |
| 19 | + |
| 20 | +template <typename scalar_t> |
| 21 | +__device__ __forceinline__ scalar_t d_tanh(scalar_t z) { |
| 22 | + const auto t = tanh(z); |
| 23 | + return 1 - (t * t); |
| 24 | +} |
| 25 | + |
| 26 | +template <typename scalar_t> |
| 27 | +__global__ void fastgrnn_cuda_forward_kernel( |
| 28 | + const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> pre_comp, |
| 29 | + const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h, |
| 30 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h, |
| 31 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z_t, |
| 32 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime_t, |
| 33 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_z, |
| 34 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_h_prime, |
| 35 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> zeta, |
| 36 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu) { |
| 37 | + //batch index |
| 38 | + const int n = blockIdx.y; |
| 39 | + // column index |
| 40 | + const int c = blockIdx.x * blockDim.x + threadIdx.x; |
| 41 | + if (c < pre_comp.size(1)){ |
| 42 | + z_t[n][c] = sigmoid(pre_comp[n][c] + bias_z[n][c]); |
| 43 | + h_prime_t[n][c] = tanh(pre_comp[n][c] + bias_h_prime[n][c]); |
| 44 | + |
| 45 | + new_h[n][c] = (sigmoid(zeta[0][0]) * (1 - z_t[n][c]) + sigmoid(nu[0][0])) * h_prime_t[n][c] + z_t[n][c] * old_h[n][c]; |
| 46 | + } |
| 47 | +} |
| 48 | + |
| 49 | +template <typename scalar_t> |
| 50 | +__global__ void fastgrnn_cuda_backward_kernel( |
| 51 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta, |
| 52 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu, |
| 53 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_precomp, |
| 54 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_z, |
| 55 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_h_prime_t, |
| 56 | + torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_h, |
| 57 | + const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h, |
| 58 | + const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h, |
| 59 | + const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z_t, |
| 60 | + const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime_t, |
| 61 | + const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> pre_comp, |
| 62 | + const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_z, |
| 63 | + const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> bias_h_prime, |
| 64 | + const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> zeta, |
| 65 | + const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu) { |
| 66 | + //batch index |
| 67 | + const int n = blockIdx.y; |
| 68 | + // column index |
| 69 | + const int c = blockIdx.x * blockDim.x + threadIdx.x; |
| 70 | + if (c < d_precomp.size(1)){ |
| 71 | + auto temp_grad = grad_h[n][c] * h_prime_t[n][c]; |
| 72 | + d_zeta[0][0] = temp_grad * (1 - z_t[n][c]) * d_sigmoid(zeta[0][0]); |
| 73 | + d_nu[0][0] = temp_grad * d_sigmoid(nu[0][0]); |
| 74 | + d_bias_z[n][c] = grad_h[n][c] * (sigmoid(zeta[0][0]) * -1 * h_prime_t[n][c] + old_h[n][c]) * d_sigmoid(pre_comp[n][c] + bias_z[n][c]);; |
| 75 | + d_bias_h_prime_t[n][c] = grad_h[n][c] * (sigmoid(zeta[0][0]) * (1 - z_t[n][c]) + sigmoid(nu[0][0])) * d_tanh(pre_comp[n][c] + bias_h_prime[n][c]); |
| 76 | + d_old_h[n][c] = grad_h[n][c] * z_t[n][c]; |
| 77 | + d_precomp[n][c] = d_bias_z[n][c] + d_bias_h_prime_t[n][c]; |
| 78 | + } |
| 79 | +} |
| 80 | +} // namespace |
| 81 | + |
| 82 | +std::vector<torch::Tensor> fastgrnn_cuda_forward( |
| 83 | + torch::Tensor input, |
| 84 | + torch::Tensor w, |
| 85 | + torch::Tensor u, |
| 86 | + torch::Tensor bias_z, |
| 87 | + torch::Tensor bias_h_prime, |
| 88 | + torch::Tensor old_h, |
| 89 | + torch::Tensor zeta, |
| 90 | + torch::Tensor nu) { |
| 91 | + auto w_comp = torch::mm(input, w); |
| 92 | + auto u_comp = torch::mm(old_h, u); |
| 93 | + auto pre_comp = torch::add(u_comp, w_comp); |
| 94 | + |
| 95 | + const auto batch_size = old_h.size(0); |
| 96 | + const auto state_size = old_h.size(1); |
| 97 | + |
| 98 | + auto new_h = torch::zeros_like(old_h); |
| 99 | + auto z_t = torch::zeros_like(old_h); |
| 100 | + auto h_prime_t = torch::zeros_like(old_h); |
| 101 | + |
| 102 | + const int threads = 1024; |
| 103 | + const dim3 blocks((state_size + threads - 1) / threads, batch_size); |
| 104 | + |
| 105 | + AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] { |
| 106 | + fastgrnn_cuda_forward_kernel<scalar_t><<<blocks, threads>>>( |
| 107 | + pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 108 | + old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 109 | + new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 110 | + z_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 111 | + h_prime_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 112 | + bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 113 | + bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 114 | + zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 115 | + nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>()); |
| 116 | + })); |
| 117 | + |
| 118 | + return {new_h, z_t, h_prime_t, pre_comp}; |
| 119 | +} |
| 120 | + |
| 121 | +std::vector<torch::Tensor> fastgrnn_cuda_backward( |
| 122 | + torch::Tensor grad_h, |
| 123 | + torch::Tensor input, |
| 124 | + torch::Tensor old_h, |
| 125 | + torch::Tensor z_t, |
| 126 | + torch::Tensor h_prime_t, |
| 127 | + torch::Tensor pre_comp, |
| 128 | + torch::Tensor w, |
| 129 | + torch::Tensor u, |
| 130 | + torch::Tensor bias_z, |
| 131 | + torch::Tensor bias_h_prime, |
| 132 | + torch::Tensor zeta, |
| 133 | + torch::Tensor nu) { |
| 134 | + auto d_precomp = torch::zeros_like(pre_comp); |
| 135 | + auto d_old_h = torch::zeros_like(old_h); |
| 136 | + auto d_zeta = torch::zeros_like(zeta); |
| 137 | + auto d_nu = torch::zeros_like(nu); |
| 138 | + auto d_bias_z = torch::zeros_like(bias_z); |
| 139 | + auto d_bias_h_prime = torch::zeros_like(bias_h_prime); |
| 140 | + |
| 141 | + const auto batch_size = old_h.size(0); |
| 142 | + const auto state_size = old_h.size(1); |
| 143 | + |
| 144 | + const int threads = 1024; |
| 145 | + const dim3 blocks((state_size + threads - 1) / threads, batch_size); |
| 146 | + |
| 147 | + AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] { |
| 148 | + fastgrnn_cuda_backward_kernel<scalar_t><<<blocks, threads>>>( |
| 149 | + d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 150 | + d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 151 | + d_precomp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 152 | + d_bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 153 | + d_bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 154 | + d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 155 | + grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 156 | + old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 157 | + z_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 158 | + h_prime_t.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 159 | + pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 160 | + bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 161 | + bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 162 | + zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), |
| 163 | + nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>()); |
| 164 | + })); |
| 165 | + |
| 166 | + d_old_h = torch::add(d_old_h, torch::mm(torch::add(d_bias_h_prime, d_bias_z), u.transpose(0, 1))); |
| 167 | + auto d_input = torch::mm(d_precomp, w.transpose(0, 1)); |
| 168 | + auto d_w = torch::mm(input.transpose(0, 1), d_precomp); |
| 169 | + auto d_u = torch::mm(old_h.transpose(0, 1), d_precomp); |
| 170 | + |
| 171 | + return {d_old_h, d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_nu, d_zeta}; |
| 172 | +} |
0 commit comments