Skip to content

Commit cb1e26f

Browse files
committed
fastgrnncuda: fix gradient return
1 parent 4bcdf17 commit cb1e26f

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

pytorch/edgeml_pytorch/cuda/fastgrnn_cuda_kernel.cu

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,10 @@ std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
397397
auto d_nu = torch::zeros_like(initial_h);
398398
auto d_bias_z = torch::zeros_like(initial_h);
399399
auto d_bias_h_prime = torch::zeros_like(initial_h);
400+
auto d_w1 = torch::empty(0);
401+
auto d_w2 = torch::empty(0);
402+
auto d_u1 = torch::empty(0);
403+
auto d_u2 = torch::empty(0);
400404

401405
bool w_low_rank = w1.size(0) != 0;
402406
bool u_low_rank = u1.size(0) != 0;
@@ -501,20 +505,14 @@ std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
501505
d_zeta = (d_zeta.sum(0, true)).sum(1, true);
502506
d_nu = (d_nu.sum(0, true)).sum(1, true);
503507
if (w_low_rank) {
504-
auto d_w1 = torch::mm(w2.transpose(0, 1), d_w);
505-
auto d_w2 = torch::mm(d_w, w1.transpose(0, 1));
508+
d_w1 = torch::mm(w2.transpose(0, 1), d_w);
509+
d_w2 = torch::mm(d_w, w1.transpose(0, 1));
506510
d_w = torch::empty(0);
507-
} else {
508-
auto d_w1 = torch::empty(0);
509-
auto d_w2 = torch::empty(0);
510511
}
511512
if(u_low_rank) {
512-
auto d_u1 = torch::mm(u2.transpose(0, 1), d_u);
513-
auto d_u2 = torch::mm(d_u, u1.transpose(0, 1));
513+
d_u1 = torch::mm(u2.transpose(0, 1), d_u);
514+
d_u2 = torch::mm(d_u, u1.transpose(0, 1));
514515
d_u = torch::empty(0);
515-
} else {
516-
auto d_u1 = torch::empty(0);
517-
auto d_u2 = torch::empty(0);
518516
}
519517
return {d_input, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h, d_w, d_u, d_w1, d_w2, d_u1, d_u2};
520518
}

0 commit comments

Comments
 (0)