@@ -397,6 +397,10 @@ std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
397
397
auto d_nu = torch::zeros_like (initial_h);
398
398
auto d_bias_z = torch::zeros_like (initial_h);
399
399
auto d_bias_h_prime = torch::zeros_like (initial_h);
400
+ auto d_w1 = torch::empty (0 );
401
+ auto d_w2 = torch::empty (0 );
402
+ auto d_u1 = torch::empty (0 );
403
+ auto d_u2 = torch::empty (0 );
400
404
401
405
bool w_low_rank = w1.size (0 ) != 0 ;
402
406
bool u_low_rank = u1.size (0 ) != 0 ;
@@ -501,20 +505,14 @@ std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
501
505
d_zeta = (d_zeta.sum (0 , true )).sum (1 , true );
502
506
d_nu = (d_nu.sum (0 , true )).sum (1 , true );
503
507
if (w_low_rank) {
504
- auto d_w1 = torch::mm (w2.transpose (0 , 1 ), d_w);
505
- auto d_w2 = torch::mm (d_w, w1.transpose (0 , 1 ));
508
+ d_w1 = torch::mm (w2.transpose (0 , 1 ), d_w);
509
+ d_w2 = torch::mm (d_w, w1.transpose (0 , 1 ));
506
510
d_w = torch::empty (0 );
507
- } else {
508
- auto d_w1 = torch::empty (0 );
509
- auto d_w2 = torch::empty (0 );
510
511
}
511
512
if (u_low_rank) {
512
- auto d_u1 = torch::mm (u2.transpose (0 , 1 ), d_u);
513
- auto d_u2 = torch::mm (d_u, u1.transpose (0 , 1 ));
513
+ d_u1 = torch::mm (u2.transpose (0 , 1 ), d_u);
514
+ d_u2 = torch::mm (d_u, u1.transpose (0 , 1 ));
514
515
d_u = torch::empty (0 );
515
- } else {
516
- auto d_u1 = torch::empty (0 );
517
- auto d_u2 = torch::empty (0 );
518
516
}
519
517
return {d_input, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h, d_w, d_u, d_w1, d_w2, d_u1, d_u2};
520
518
}
0 commit comments