|
| 1 | +#include <torch/extension.h> |
| 2 | + |
| 3 | +#include <vector> |
| 4 | + |
| 5 | +std::vector<torch::Tensor> fastgrnn_cuda_forward( |
| 6 | + torch::Tensor input, |
| 7 | + torch::Tensor w, |
| 8 | + torch::Tensor u, |
| 9 | + torch::Tensor bias_gate, |
| 10 | + torch::Tensor bias_update, |
| 11 | + torch::Tensor zeta, |
| 12 | + torch::Tensor nu, |
| 13 | + torch::Tensor old_h, |
| 14 | + int z_non_linearity, |
| 15 | + torch::Tensor w1, |
| 16 | + torch::Tensor w2, |
| 17 | + torch::Tensor u1, |
| 18 | + torch::Tensor u2); |
| 19 | + |
| 20 | +std::vector<torch::Tensor> fastgrnn_cuda_backward( |
| 21 | + torch::Tensor grad_h, |
| 22 | + torch::Tensor input, |
| 23 | + torch::Tensor old_h, |
| 24 | + torch::Tensor zeta, |
| 25 | + torch::Tensor nu, |
| 26 | + torch::Tensor w, |
| 27 | + torch::Tensor u, |
| 28 | + int z_non_linearity, |
| 29 | + torch::Tensor z, |
| 30 | + torch::Tensor h_prime, |
| 31 | + torch::Tensor w1, |
| 32 | + torch::Tensor w2, |
| 33 | + torch::Tensor u1, |
| 34 | + torch::Tensor u2); |
| 35 | + |
| 36 | +std::vector<torch::Tensor> fastgrnn_unroll_cuda_forward( |
| 37 | + torch::Tensor input, |
| 38 | + torch::Tensor w, |
| 39 | + torch::Tensor u, |
| 40 | + torch::Tensor bias_z, |
| 41 | + torch::Tensor bias_h_prime, |
| 42 | + torch::Tensor zeta, |
| 43 | + torch::Tensor nu, |
| 44 | + torch::Tensor initial_h, |
| 45 | + int z_non_linearity, |
| 46 | + torch::Tensor w1, |
| 47 | + torch::Tensor w2, |
| 48 | + torch::Tensor u1, |
| 49 | + torch::Tensor u2); |
| 50 | + |
| 51 | +std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward( |
| 52 | + torch::Tensor grad_h, |
| 53 | + torch::Tensor input, |
| 54 | + torch::Tensor hidden_states, |
| 55 | + torch::Tensor zeta, |
| 56 | + torch::Tensor nu, |
| 57 | + torch::Tensor w, |
| 58 | + torch::Tensor u, |
| 59 | + torch::Tensor z, |
| 60 | + torch::Tensor h_prime, |
| 61 | + torch::Tensor initial_h, |
| 62 | + int z_non_linearity, |
| 63 | + torch::Tensor w1, |
| 64 | + torch::Tensor w2, |
| 65 | + torch::Tensor u1, |
| 66 | + torch::Tensor u2); |
| 67 | + |
| 68 | +#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") |
| 69 | +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") |
| 70 | +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) |
| 71 | + |
| 72 | +std::vector<torch::Tensor> fastgrnn_forward( |
| 73 | + torch::Tensor input, |
| 74 | + torch::Tensor w, |
| 75 | + torch::Tensor u, |
| 76 | + torch::Tensor bias_gate, |
| 77 | + torch::Tensor bias_update, |
| 78 | + torch::Tensor zeta, |
| 79 | + torch::Tensor nu, |
| 80 | + torch::Tensor old_h, |
| 81 | + int z_non_linearity, |
| 82 | + torch::Tensor w1, |
| 83 | + torch::Tensor w2, |
| 84 | + torch::Tensor u1, |
| 85 | + torch::Tensor u2) { |
| 86 | + CHECK_INPUT(input); |
| 87 | + if(w1.size(0) == 0) { |
| 88 | + CHECK_INPUT(w); |
| 89 | + } else { |
| 90 | + CHECK_INPUT(w1); |
| 91 | + CHECK_INPUT(w2); |
| 92 | + } |
| 93 | + if (u1.size(0) == 0) { |
| 94 | + CHECK_INPUT(u); |
| 95 | + } else { |
| 96 | + CHECK_INPUT(u1); |
| 97 | + CHECK_INPUT(u2); |
| 98 | + } |
| 99 | + CHECK_INPUT(bias_gate); |
| 100 | + CHECK_INPUT(bias_update); |
| 101 | + CHECK_INPUT(zeta); |
| 102 | + CHECK_INPUT(nu); |
| 103 | + CHECK_INPUT(old_h); |
| 104 | + |
| 105 | + return fastgrnn_cuda_forward(input, w, u, bias_gate, bias_update, zeta, nu, old_h, z_non_linearity, w1, w2, u1, u2); |
| 106 | +} |
| 107 | + |
| 108 | +std::vector<torch::Tensor> fastgrnn_backward( |
| 109 | + torch::Tensor grad_h, |
| 110 | + torch::Tensor input, |
| 111 | + torch::Tensor old_h, |
| 112 | + torch::Tensor zeta, |
| 113 | + torch::Tensor nu, |
| 114 | + torch::Tensor w, |
| 115 | + torch::Tensor u, |
| 116 | + torch::Tensor z, |
| 117 | + torch::Tensor h_prime, |
| 118 | + torch::Tensor w1, |
| 119 | + torch::Tensor w2, |
| 120 | + torch::Tensor u1, |
| 121 | + torch::Tensor u2, |
| 122 | + int z_non_linearity) { |
| 123 | + CHECK_INPUT(grad_h); |
| 124 | + CHECK_INPUT(input); |
| 125 | + CHECK_INPUT(old_h); |
| 126 | + CHECK_INPUT(zeta); |
| 127 | + CHECK_INPUT(nu); |
| 128 | + CHECK_INPUT(z); |
| 129 | + CHECK_INPUT(h_prime); |
| 130 | + if(w1.size(0) == 0) { |
| 131 | + CHECK_INPUT(w); |
| 132 | + } else { |
| 133 | + CHECK_INPUT(w1); |
| 134 | + CHECK_INPUT(w2); |
| 135 | + } |
| 136 | + if (u1.size(0) == 0) { |
| 137 | + CHECK_INPUT(u); |
| 138 | + } else { |
| 139 | + CHECK_INPUT(u1); |
| 140 | + CHECK_INPUT(u2); |
| 141 | + } |
| 142 | + |
| 143 | + return fastgrnn_cuda_backward(grad_h, input, old_h, zeta, nu, w, u, z_non_linearity, z, h_prime, w1, w2, u1, u2); |
| 144 | +} |
| 145 | + |
| 146 | +std::vector<torch::Tensor> fastgrnn_unroll_forward( |
| 147 | + torch::Tensor input, |
| 148 | + torch::Tensor w, |
| 149 | + torch::Tensor u, |
| 150 | + torch::Tensor bias_z, |
| 151 | + torch::Tensor bias_h_prime, |
| 152 | + torch::Tensor zeta, |
| 153 | + torch::Tensor nu, |
| 154 | + torch::Tensor initial_h, |
| 155 | + int z_non_linearity, |
| 156 | + torch::Tensor w1, |
| 157 | + torch::Tensor w2, |
| 158 | + torch::Tensor u1, |
| 159 | + torch::Tensor u2) { |
| 160 | + CHECK_INPUT(input); |
| 161 | + if(w1.size(0) == 0) { |
| 162 | + CHECK_INPUT(w); |
| 163 | + } else { |
| 164 | + CHECK_INPUT(w1); |
| 165 | + CHECK_INPUT(w2); |
| 166 | + } |
| 167 | + if (u1.size(0) == 0) { |
| 168 | + CHECK_INPUT(u); |
| 169 | + } else { |
| 170 | + CHECK_INPUT(u1); |
| 171 | + CHECK_INPUT(u2); |
| 172 | + } |
| 173 | + CHECK_INPUT(bias_z); |
| 174 | + CHECK_INPUT(bias_h_prime); |
| 175 | + CHECK_INPUT(initial_h); |
| 176 | + CHECK_INPUT(zeta); |
| 177 | + CHECK_INPUT(nu); |
| 178 | + return fastgrnn_unroll_cuda_forward(input, w, u, bias_z, bias_h_prime, zeta, nu, initial_h, z_non_linearity, w1, w2, u1, u2); |
| 179 | +} |
| 180 | + |
| 181 | +std::vector<torch::Tensor> fastgrnn_unroll_backward( |
| 182 | + torch::Tensor grad_h, |
| 183 | + torch::Tensor input, |
| 184 | + torch::Tensor hidden_states, |
| 185 | + torch::Tensor zeta, |
| 186 | + torch::Tensor nu, |
| 187 | + torch::Tensor w, |
| 188 | + torch::Tensor u, |
| 189 | + torch::Tensor z, |
| 190 | + torch::Tensor h_prime, |
| 191 | + torch::Tensor initial_h, |
| 192 | + torch::Tensor w1, |
| 193 | + torch::Tensor w2, |
| 194 | + torch::Tensor u1, |
| 195 | + torch::Tensor u2, |
| 196 | + int z_non_linearity) { |
| 197 | + CHECK_INPUT(grad_h); |
| 198 | + CHECK_INPUT(input); |
| 199 | + CHECK_INPUT(hidden_states); |
| 200 | + CHECK_INPUT(z); |
| 201 | + CHECK_INPUT(h_prime); |
| 202 | + if(w1.size(0) == 0) { |
| 203 | + CHECK_INPUT(w); |
| 204 | + } else { |
| 205 | + CHECK_INPUT(w1); |
| 206 | + CHECK_INPUT(w2); |
| 207 | + } |
| 208 | + if (u1.size(0) == 0) { |
| 209 | + CHECK_INPUT(u); |
| 210 | + } else { |
| 211 | + CHECK_INPUT(u1); |
| 212 | + CHECK_INPUT(u2); |
| 213 | + } |
| 214 | + CHECK_INPUT(zeta); |
| 215 | + CHECK_INPUT(nu); |
| 216 | + CHECK_INPUT(initial_h); |
| 217 | + |
| 218 | + return fastgrnn_unroll_cuda_backward( |
| 219 | + grad_h, |
| 220 | + input, |
| 221 | + hidden_states, |
| 222 | + zeta, |
| 223 | + nu, |
| 224 | + w, |
| 225 | + u, |
| 226 | + z, |
| 227 | + h_prime, |
| 228 | + initial_h, |
| 229 | + z_non_linearity, |
| 230 | + w1, w2, u1, u2); |
| 231 | +} |
| 232 | + |
| 233 | + |
| 234 | +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { |
| 235 | + m.def("forward", &fastgrnn_forward, "FastGRNN forward (CUDA)"); |
| 236 | + m.def("backward", &fastgrnn_backward, "FastGRNN backward (CUDA)"); |
| 237 | + m.def("forward_unroll", &fastgrnn_unroll_forward, "FastGRNN Unrolled forward (CUDA)"); |
| 238 | + m.def("backward_unroll", &fastgrnn_unroll_backward, "FastGRNN Unrolled backward (CUDA)"); |
| 239 | +} |
0 commit comments