Skip to content

Commit e4ce97f

Browse files
committed
fastgrnncuda: adds support for different non-linearities
1 parent 61ce245 commit e4ce97f

File tree

3 files changed

+200
-117
lines changed

3 files changed

+200
-117
lines changed

pytorch/edgeml_pytorch/cuda/fastgrnn_cuda.cpp

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,24 @@
44

55
std::vector<torch::Tensor> fastgrnn_cuda_forward(
66
torch::Tensor input,
7-
torch::Tensor w,
8-
torch::Tensor u,
7+
torch::Tensor W,
8+
torch::Tensor U,
99
torch::Tensor bias_gate,
1010
torch::Tensor bias_update,
1111
torch::Tensor zeta,
1212
torch::Tensor nu,
13-
torch::Tensor old_h);
13+
torch::Tensor old_h,
14+
int z_non_linearity);
1415

1516
std::vector<torch::Tensor> fastgrnn_cuda_backward(
1617
torch::Tensor grad_h,
1718
torch::Tensor input,
1819
torch::Tensor old_h,
1920
torch::Tensor zeta,
2021
torch::Tensor nu,
21-
torch::Tensor w,
22-
torch::Tensor u,
22+
torch::Tensor W,
23+
torch::Tensor U,
24+
int z_non_linearity,
2325
torch::Tensor z,
2426
torch::Tensor h_prime);
2527

@@ -29,23 +31,24 @@ std::vector<torch::Tensor> fastgrnn_cuda_backward(
2931

3032
std::vector<torch::Tensor> fastgrnn_forward(
3133
torch::Tensor input,
32-
torch::Tensor w,
33-
torch::Tensor u,
34+
torch::Tensor W,
35+
torch::Tensor U,
3436
torch::Tensor bias_gate,
3537
torch::Tensor bias_update,
3638
torch::Tensor zeta,
3739
torch::Tensor nu,
38-
torch::Tensor old_h) {
40+
torch::Tensor old_h,
41+
int z_non_linearity) {
3942
CHECK_INPUT(input);
40-
CHECK_INPUT(w);
41-
CHECK_INPUT(u);
43+
CHECK_INPUT(W);
44+
CHECK_INPUT(U);
4245
CHECK_INPUT(bias_gate);
4346
CHECK_INPUT(bias_update);
4447
CHECK_INPUT(zeta);
4548
CHECK_INPUT(nu);
4649
CHECK_INPUT(old_h);
4750

48-
return fastgrnn_cuda_forward(input, w, u, bias_gate, bias_update, zeta, nu, old_h);
51+
return fastgrnn_cuda_forward(input, W, U, bias_gate, bias_update, zeta, nu, old_h, z_non_linearity);
4952
}
5053

5154
std::vector<torch::Tensor> fastgrnn_backward(
@@ -54,21 +57,22 @@ std::vector<torch::Tensor> fastgrnn_backward(
5457
torch::Tensor old_h,
5558
torch::Tensor zeta,
5659
torch::Tensor nu,
57-
torch::Tensor w,
58-
torch::Tensor u,
60+
torch::Tensor W,
61+
torch::Tensor U,
5962
torch::Tensor z,
60-
torch::Tensor h_prime) {
63+
torch::Tensor h_prime,
64+
int z_non_linearity) {
6165
CHECK_INPUT(grad_h);
6266
CHECK_INPUT(input);
6367
CHECK_INPUT(old_h);
6468
CHECK_INPUT(zeta);
6569
CHECK_INPUT(nu);
6670
CHECK_INPUT(z);
6771
CHECK_INPUT(h_prime);
68-
CHECK_INPUT(w);
69-
CHECK_INPUT(u);
72+
CHECK_INPUT(W);
73+
CHECK_INPUT(U);
7074

71-
return fastgrnn_cuda_backward(grad_h, input, old_h, zeta, nu, w, u, z, h_prime);
75+
return fastgrnn_cuda_backward(grad_h, input, old_h, zeta, nu, W, U, z_non_linearity, z, h_prime);
7276
}
7377

7478
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

0 commit comments

Comments
 (0)