4
4
5
5
std::vector<torch::Tensor> fastgrnn_cuda_forward (
6
6
torch::Tensor input,
7
- torch::Tensor w ,
8
- torch::Tensor u ,
7
+ torch::Tensor W ,
8
+ torch::Tensor U ,
9
9
torch::Tensor bias_gate,
10
10
torch::Tensor bias_update,
11
11
torch::Tensor zeta,
12
12
torch::Tensor nu,
13
- torch::Tensor old_h);
13
+ torch::Tensor old_h,
14
+ int z_non_linearity);
14
15
15
16
std::vector<torch::Tensor> fastgrnn_cuda_backward (
16
17
torch::Tensor grad_h,
17
18
torch::Tensor input,
18
19
torch::Tensor old_h,
19
20
torch::Tensor zeta,
20
21
torch::Tensor nu,
21
- torch::Tensor w,
22
- torch::Tensor u,
22
+ torch::Tensor W,
23
+ torch::Tensor U,
24
+ int z_non_linearity,
23
25
torch::Tensor z,
24
26
torch::Tensor h_prime);
25
27
@@ -29,23 +31,24 @@ std::vector<torch::Tensor> fastgrnn_cuda_backward(
29
31
30
32
std::vector<torch::Tensor> fastgrnn_forward (
31
33
torch::Tensor input,
32
- torch::Tensor w ,
33
- torch::Tensor u ,
34
+ torch::Tensor W ,
35
+ torch::Tensor U ,
34
36
torch::Tensor bias_gate,
35
37
torch::Tensor bias_update,
36
38
torch::Tensor zeta,
37
39
torch::Tensor nu,
38
- torch::Tensor old_h) {
40
+ torch::Tensor old_h,
41
+ int z_non_linearity) {
39
42
CHECK_INPUT (input);
40
- CHECK_INPUT (w );
41
- CHECK_INPUT (u );
43
+ CHECK_INPUT (W );
44
+ CHECK_INPUT (U );
42
45
CHECK_INPUT (bias_gate);
43
46
CHECK_INPUT (bias_update);
44
47
CHECK_INPUT (zeta);
45
48
CHECK_INPUT (nu);
46
49
CHECK_INPUT (old_h);
47
50
48
- return fastgrnn_cuda_forward (input, w, u , bias_gate, bias_update, zeta, nu, old_h);
51
+ return fastgrnn_cuda_forward (input, W, U , bias_gate, bias_update, zeta, nu, old_h, z_non_linearity );
49
52
}
50
53
51
54
std::vector<torch::Tensor> fastgrnn_backward (
@@ -54,21 +57,22 @@ std::vector<torch::Tensor> fastgrnn_backward(
54
57
torch::Tensor old_h,
55
58
torch::Tensor zeta,
56
59
torch::Tensor nu,
57
- torch::Tensor w ,
58
- torch::Tensor u ,
60
+ torch::Tensor W ,
61
+ torch::Tensor U ,
59
62
torch::Tensor z,
60
- torch::Tensor h_prime) {
63
+ torch::Tensor h_prime,
64
+ int z_non_linearity) {
61
65
CHECK_INPUT (grad_h);
62
66
CHECK_INPUT (input);
63
67
CHECK_INPUT (old_h);
64
68
CHECK_INPUT (zeta);
65
69
CHECK_INPUT (nu);
66
70
CHECK_INPUT (z);
67
71
CHECK_INPUT (h_prime);
68
- CHECK_INPUT (w );
69
- CHECK_INPUT (u );
72
+ CHECK_INPUT (W );
73
+ CHECK_INPUT (U );
70
74
71
- return fastgrnn_cuda_backward (grad_h, input, old_h, zeta, nu, w, u , z, h_prime);
75
+ return fastgrnn_cuda_backward (grad_h, input, old_h, zeta, nu, W, U, z_non_linearity , z, h_prime);
72
76
}
73
77
74
78
PYBIND11_MODULE (TORCH_EXTENSION_NAME, m) {
0 commit comments