@@ -12,6 +12,8 @@ module m_cuda_backend
1212 use m_cuda_sendrecv, only: sendrecv_fields, sendrecv_3fields
1313 use m_cuda_tdsops, only: cuda_tdsops_t
1414 use m_cuda_kernels_dist, only: transeq_3fused_dist, transeq_3fused_subs
15+ use m_cuda_kernels_trans, only: trans_x2y_k, trans_x2z_k, trans_y2x_k, &
16+ trans_y2z_k, trans_z2y_k
1517
1618 implicit none
1719
@@ -74,6 +76,10 @@ function init(globs, allocator) result(backend)
7476 backend% zthreads = dim3(SZ, 1 , 1 )
7577 backend% zblocks = dim3(globs% n_groups_z, 1 , 1 )
7678
79+ backend% nx_loc = globs% nx_loc
80+ backend% ny_loc = globs% ny_loc
81+ backend% nz_loc = globs% nz_loc
82+
7783 n_halo = 4
7884 n_block = globs% n_groups_x
7985
@@ -422,6 +428,17 @@ subroutine trans_x2y_cuda(self, u_y, u)
422428 class(field_t), intent (inout ) :: u_y
423429 class(field_t), intent (in ) :: u
424430
431+ real (dp), device, pointer , dimension (:, :, :) :: u_d, u_y_d
432+ type (dim3) :: blocks, threads
433+
434+ select type (u); type is (cuda_field_t); u_d = > u% data_d; end select
435+ select type (u_y); type is (cuda_field_t); u_y_d = > u_y% data_d; end select
436+
437+ blocks = dim3(self% nx_loc/ SZ, self% nz_loc, self% ny_loc/ SZ)
438+ threads = dim3(SZ, SZ, 1 )
439+
440+ call trans_x2y_k<<<blocks, threads>>>(u_y_d, u_d, self% nz_loc)
441+
425442 end subroutine trans_x2y_cuda
426443
427444 subroutine trans_x2z_cuda (self , u_z , u )
@@ -431,6 +448,17 @@ subroutine trans_x2z_cuda(self, u_z, u)
431448 class(field_t), intent (inout ) :: u_z
432449 class(field_t), intent (in ) :: u
433450
451+ real (dp), device, pointer , dimension (:, :, :) :: u_d, u_z_d
452+ type (dim3) :: blocks, threads
453+
454+ select type (u); type is (cuda_field_t); u_d = > u% data_d; end select
455+ select type (u_z); type is (cuda_field_t); u_z_d = > u_z% data_d; end select
456+
457+ blocks = dim3(self% nx_loc, self% ny_loc/ SZ, 1 )
458+ threads = dim3(SZ, 1 , 1 )
459+
460+ call trans_x2z_k<<<blocks, threads>>>(u_z_d, u_d, self% nz_loc)
461+
434462 end subroutine trans_x2z_cuda
435463
436464 subroutine trans_y2z_cuda (self , u_z , u_y )
@@ -440,6 +468,18 @@ subroutine trans_y2z_cuda(self, u_z, u_y)
440468 class(field_t), intent (inout ) :: u_z
441469 class(field_t), intent (in ) :: u_y
442470
471+ real (dp), device, pointer , dimension (:, :, :) :: u_z_d, u_y_d
472+ type (dim3) :: blocks, threads
473+
474+ select type (u_z); type is (cuda_field_t); u_z_d = > u_z% data_d; end select
475+ select type (u_y); type is (cuda_field_t); u_y_d = > u_y% data_d; end select
476+
477+ blocks = dim3(self% nx_loc/ SZ, self% ny_loc/ SZ, self% nz_loc)
478+ threads = dim3(SZ, SZ, 1 )
479+
480+ call trans_y2z_k<<<blocks, threads>>>(u_z_d, u_y_d, &
481+ self% nx_loc, self% nz_loc)
482+
443483 end subroutine trans_y2z_cuda
444484
445485 subroutine trans_z2y_cuda (self , u_y , u_z )
@@ -449,6 +489,18 @@ subroutine trans_z2y_cuda(self, u_y, u_z)
449489 class(field_t), intent (inout ) :: u_y
450490 class(field_t), intent (in ) :: u_z
451491
492+ real (dp), device, pointer , dimension (:, :, :) :: u_y_d, u_z_d
493+ type (dim3) :: blocks, threads
494+
495+ select type (u_y); type is (cuda_field_t); u_y_d = > u_y% data_d; end select
496+ select type (u_z); type is (cuda_field_t); u_z_d = > u_z% data_d; end select
497+
498+ blocks = dim3(self% nx_loc/ SZ, self% ny_loc/ SZ, self% nz_loc)
499+ threads = dim3(SZ, SZ, 1 )
500+
501+ call trans_z2y_k<<<blocks, threads>>>(u_y_d, u_z_d, &
502+ self% nx_loc, self% nz_loc)
503+
452504 end subroutine trans_z2y_cuda
453505
454506 subroutine trans_y2x_cuda (self , u_x , u_y )
@@ -458,6 +510,17 @@ subroutine trans_y2x_cuda(self, u_x, u_y)
458510 class(field_t), intent (inout ) :: u_x
459511 class(field_t), intent (in ) :: u_y
460512
513+ real (dp), device, pointer , dimension (:, :, :) :: u_x_d, u_y_d
514+ type (dim3) :: blocks, threads
515+
516+ select type (u_x); type is (cuda_field_t); u_x_d = > u_x% data_d; end select
517+ select type (u_y); type is (cuda_field_t); u_y_d = > u_y% data_d; end select
518+
519+ blocks = dim3(self% nx_loc/ SZ, self% ny_loc/ SZ, self% nz_loc)
520+ threads = dim3(SZ, SZ, 1 )
521+
522+ call trans_y2x_k<<<blocks, threads>>>(u_x_d, u_y_d, self% nz_loc)
523+
461524 end subroutine trans_y2x_cuda
462525
463526 subroutine sum_yzintox_cuda (self , du , dv , dw , &
0 commit comments