@@ -87,6 +87,37 @@ __global__ void fastgrnn_cuda_backward_kernel(
     d_nu[n][c] = h_prime[n][c] * grad_h[n][c] * d_nu_sigmoid[0][0];
   }
 }
+
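+// Backward kernel for one timestep of the unrolled FastGRNN cell
+//   h_t = (zeta * (1 - z_t) + nu) * h_prime_t + z_t * h_{t-1},
+// where zeta and nu are the sigmoid-squashed scalar gate parameters, z_t is the gate
+// and h_prime_t the tanh candidate state. Each thread handles one (batch, hidden-unit)
+// element and writes per-element gradients; parameter gradients are reduced on the host.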
+template <typename scalar_t, scalar_t (*d_non_linearity)(scalar_t)>
+__global__ void fastgrnn_unroll_cuda_backward_kernel(
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_precomp,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_h,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_z,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_bias_h_prime,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> z,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> h_prime,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> zeta,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> nu,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_zeta_sigmoid,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_nu_sigmoid,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h) {
+  const int n = blockIdx.y;
+  const int c = blockIdx.x * blockDim.x + threadIdx.x;
+  if (c < old_h.size(1)) {
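+    // d_old_h is the direct contribution of h_{t-1} to h_t (scaled by the gate z_t);
+    // the recurrent contribution through U is added on the host side from d_precomp.
+    // temp_bias_h_prime and temp_bias_z are the gradients w.r.t. the two pre-activation
+    // branches (tanh candidate and gate); their sum is the gradient w.r.t. the shared
+    // pre-computation W x_t + U h_{t-1}.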
+    d_old_h[n][c] = z[n][c] * grad_h[n][c];
+    scalar_t temp_bias_h_prime = (zeta[0][0] * (1.0 - z[n][c]) + nu[0][0]) * d_tanh(h_prime[n][c]) * grad_h[n][c];
+    scalar_t temp_bias_z = (old_h[n][c] - zeta[0][0] * h_prime[n][c]) * d_non_linearity(z[n][c]) * grad_h[n][c];
+    d_bias_h_prime[n][c] += temp_bias_h_prime;
+    d_bias_z[n][c] += temp_bias_z;
+    d_precomp[n][c] = temp_bias_z + temp_bias_h_prime;
+    d_zeta[n][c] += (1.0 - z[n][c]) * h_prime[n][c] * grad_h[n][c] * d_zeta_sigmoid[0][0];
+    d_nu[n][c] += h_prime[n][c] * grad_h[n][c] * d_nu_sigmoid[0][0];
+  }
+}
+
 } // namespace

 std::vector<torch::Tensor> fastgrnn_cuda_forward(
@@ -246,3 +277,202 @@ std::vector<torch::Tensor> fastgrnn_cuda_backward(

   return {d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h};
 }
+
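+// Unrolled forward pass: runs the FastGRNN cell over all timesteps on the GPU and
+// returns the per-step hidden states together with the per-step gate (z) and
+// candidate (h_prime) activations, which the backward pass below reuses.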
+std::vector<torch::Tensor> fastgrnn_unroll_cuda_forward(
+    torch::Tensor input,
+    torch::Tensor w,
+    torch::Tensor u,
+    torch::Tensor bias_z,
+    torch::Tensor bias_h_prime,
+    torch::Tensor zeta,
+    torch::Tensor nu,
+    torch::Tensor initial_h,
+    int z_non_linearity) {
+  auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device().type());
+  const auto timesteps = input.size(0);
+  const auto batch_size = initial_h.size(0);
+  const auto state_size = initial_h.size(1);
+
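+  // Per-timestep buffers saved for the backward pass.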
+  auto hidden_states = torch::zeros({timesteps, batch_size, state_size}, options);
+  auto z_s = torch::zeros_like(hidden_states);
+  auto h_prime_s = torch::zeros_like(hidden_states);
+
+  auto prev_h = initial_h;
+  auto new_h = torch::zeros_like(prev_h);
+  auto z = torch::zeros_like(prev_h);
+  auto h_prime = torch::zeros_like(prev_h);
+  auto pre_comp = torch::zeros_like(prev_h);
+
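+  // Launch configuration: one thread per hidden unit (x), one block row per batch element (y).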
+  const int threads = 1024;
+  const dim3 blocks((state_size + threads - 1) / threads, batch_size);
+
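+  // W and U are transposed once up front (rather than per step), and the scalar gate
+  // parameters zeta and nu are squashed through a sigmoid before use.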
+  w = w.transpose(0, 1);
+  u = u.transpose(0, 1);
+  zeta = torch::sigmoid(zeta);
+  nu = torch::sigmoid(nu);
+
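+  // For each timestep: compute the shared pre-activation W x_t + U h_{t-1} for the
+  // whole batch with two GEMMs, then apply the element-wise FastGRNN update on the GPU.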
+  for (int t = 0; t < timesteps; t++) {
+    pre_comp = torch::addmm(torch::mm(input[t], w), prev_h, u);
+
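+    // z_non_linearity selects the gate activation (0 = sigmoid, 1 = ReLU, 2 = tanh);
+    // AT_DISPATCH_FLOATING_TYPES instantiates the kernel for the tensor's scalar type.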
+    if (z_non_linearity == 0)
+      AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] {
+        fastgrnn_cuda_forward_kernel<scalar_t, sigmoid><<<blocks, threads>>>(
+          new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          prev_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+      }));
+    else if (z_non_linearity == 1)
+      AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] {
+        fastgrnn_cuda_forward_kernel<scalar_t, relu><<<blocks, threads>>>(
+          new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          prev_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+      }));
+    else if (z_non_linearity == 2)
+      AT_DISPATCH_FLOATING_TYPES(pre_comp.type(), "fastgrnn_forward_cuda", ([&] {
+        fastgrnn_cuda_forward_kernel<scalar_t, tanh><<<blocks, threads>>>(
+          new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          pre_comp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          prev_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+      }));
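+    // Cache this step's activations for the backward pass and advance the recurrence.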
+    hidden_states[t] = new_h;
+    z_s[t] = z;
+    h_prime_s[t] = h_prime;
+    prev_h = new_h;
+  }
+  return {hidden_states, z_s, h_prime_s};
+}
+
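+// Backpropagation through time over the unrolled sequence. Timesteps are visited in
+// reverse; d_old_h carries the gradient flowing from step t into step t-1, and the
+// parameter gradients (W, U, biases, zeta, nu) are accumulated across steps.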
+std::vector<torch::Tensor> fastgrnn_unroll_cuda_backward(
+    torch::Tensor grad_h,
+    torch::Tensor input,
+    torch::Tensor hidden_states,
+    torch::Tensor zeta,
+    torch::Tensor nu,
+    torch::Tensor w,
+    torch::Tensor u,
+    torch::Tensor z,
+    torch::Tensor h_prime,
+    torch::Tensor initial_h,
+    int z_non_linearity) {
+
+  auto d_input = torch::zeros_like(input);
+  auto d_w = torch::zeros_like(w);
+  auto d_u = torch::zeros_like(u);
+  auto d_zeta = torch::zeros_like(initial_h);
+  auto d_nu = torch::zeros_like(initial_h);
+  auto d_bias_z = torch::zeros_like(initial_h);
+  auto d_bias_h_prime = torch::zeros_like(initial_h);
+
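+  // zeta and nu enter the cell through a sigmoid, so their raw-parameter gradients
+  // pick up a d_sigmoid factor, precomputed here once.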
+  zeta = torch::sigmoid(zeta);
+  nu = torch::sigmoid(nu);
+  auto d_nu_sigmoid = d_sigmoid(nu);
+  auto d_zeta_sigmoid = d_sigmoid(zeta);
+
+  auto grad_curr_h = torch::zeros_like(initial_h);
+  auto d_precomp = torch::zeros_like(initial_h);
+  auto d_old_h = torch::zeros_like(initial_h);
+  auto prev_h_ = hidden_states[0];
+  auto z_t_ = torch::zeros_like(initial_h);
+  auto h_prime_t_ = torch::zeros_like(initial_h);
+
+  const auto batch_size = hidden_states.size(1);
+  const auto state_size = hidden_states.size(2);
+
+  const int threads = 1024;
+  const dim3 blocks((state_size + threads - 1) / threads, batch_size);
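+  // Walk the sequence backwards: the total gradient for h_t is the upstream gradient
+  // grad_h[t] plus the gradient propagated back from step t+1 (d_old_h).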
+  for (auto t = hidden_states.size(0) - 1; t >= 0; t--) {
+    grad_curr_h = torch::add(grad_h[t], d_old_h);
+    z_t_ = z[t];
+    h_prime_t_ = h_prime[t];
+
+    if (t == 0)
+      prev_h_ = initial_h;
+    else
+      prev_h_ = hidden_states[t-1];
+
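+    // The element-wise kernel produces d_precomp (gradient w.r.t. W x_t + U h_{t-1})
+    // and accumulates the bias, zeta and nu gradients; the derivative passed as the
+    // template argument must match the gate used in the forward pass.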
+    if (z_non_linearity == 0)
+      AT_DISPATCH_FLOATING_TYPES(z_t_.type(), "fastgrnn_forward_cuda", ([&] {
+        fastgrnn_unroll_cuda_backward_kernel<scalar_t, d_sigmoid><<<blocks, threads>>>(
+          d_precomp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          grad_curr_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          z_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          h_prime_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_zeta_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_nu_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          prev_h_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+      }));
+    else if (z_non_linearity == 1)
+      AT_DISPATCH_FLOATING_TYPES(z_t_.type(), "fastgrnn_forward_cuda", ([&] {
+        fastgrnn_unroll_cuda_backward_kernel<scalar_t, d_relu><<<blocks, threads>>>(
+          d_precomp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          grad_curr_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          z_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          h_prime_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_zeta_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_nu_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          prev_h_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+      }));
+    else if (z_non_linearity == 2)
+      AT_DISPATCH_FLOATING_TYPES(z_t_.type(), "fastgrnn_forward_cuda", ([&] {
+        fastgrnn_unroll_cuda_backward_kernel<scalar_t, d_tanh><<<blocks, threads>>>(
+          d_precomp.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_bias_z.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_bias_h_prime.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          grad_curr_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          z_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          h_prime_t_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          zeta.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          nu.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_zeta_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          d_nu_sigmoid.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+          prev_h_.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+      }));
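+    // Push the pre-activation gradient through U into h_{t-1} and through W into x_t,
+    // and accumulate the weight gradients as batch-summed outer products.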
+    d_old_h = torch::addmm(d_old_h, d_precomp, u);
+    d_input[t] = torch::mm(d_precomp, w);
+    d_w = torch::addmm(d_w, d_precomp.transpose(0, 1), input[t]);
+    d_u = torch::addmm(d_u, d_precomp.transpose(0, 1), prev_h_);
+    // grad_curr_h = d_old_h;
+  }
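+  // Reduce the per-element accumulators: bias gradients are summed over the batch,
+  // and the zeta/nu gradients are reduced to scalars to match their 1x1 parameters.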
+  d_bias_z = d_bias_z.sum(0, true);
+  d_bias_h_prime = d_bias_h_prime.sum(0, true);
+  d_zeta = (d_zeta.sum(0, true)).sum(1, true);
+  d_nu = (d_nu.sum(0, true)).sum(1, true);
+  return {d_input, d_w, d_u, d_bias_z, d_bias_h_prime, d_zeta, d_nu, d_old_h};
+}