@@ -8,8 +8,8 @@ module m_cuda_backend
88
99 use m_cuda_allocator, only: cuda_allocator_t, cuda_field_t
1010 use m_cuda_common, only: SZ
11- use m_cuda_exec_dist, only: exec_dist_transeq_3fused
12- use m_cuda_sendrecv, only: sendrecv_3fields
11+ use m_cuda_exec_dist, only: exec_dist_transeq_3fused, exec_dist_tds_compact
12+ use m_cuda_sendrecv, only: sendrecv_fields, sendrecv_3fields
1313 use m_cuda_tdsops, only: cuda_tdsops_t
1414 use m_cuda_kernels_dist, only: transeq_3fused_dist, transeq_3fused_subs
1515
@@ -31,13 +31,15 @@ module m_cuda_backend
3131 procedure :: transeq_x = > transeq_x_cuda
3232 procedure :: transeq_y = > transeq_y_cuda
3333 procedure :: transeq_z = > transeq_z_cuda
34+ procedure :: tds_solve = > tds_solve_cuda
3435 procedure :: trans_x2y = > trans_x2y_cuda
3536 procedure :: trans_x2z = > trans_x2z_cuda
3637 procedure :: sum_yzintox = > sum_yzintox_cuda
3738 procedure :: set_fields = > set_fields_cuda
3839 procedure :: get_fields = > get_fields_cuda
3940 procedure :: transeq_cuda_dist
4041 procedure :: transeq_cuda_thom
42+ procedure :: tds_solve_dist
4143 end type cuda_backend_t
4244
4345 interface cuda_backend_t
@@ -343,6 +345,64 @@ subroutine transeq_cuda_thom(self, du, dv, dw, u, v, w, dirps)
343345
344346 end subroutine transeq_cuda_thom
345347
348+ subroutine tds_solve_cuda (self , du , u , dirps , tdsops )
349+ implicit none
350+
351+ class(cuda_backend_t) :: self
352+ class(field_t), intent (inout ) :: du
353+ class(field_t), intent (in ) :: u
354+ type (dirps_t), intent (in ) :: dirps
355+ class(tdsops_t), intent (in ) :: tdsops
356+
357+ type (dim3) :: blocks, threads
358+
359+ blocks = dim3(dirps% n_blocks, 1 , 1 ); threads = dim3(SZ, 1 , 1 )
360+
361+ call tds_solve_dist(self, du, u, dirps, tdsops, blocks, threads)
362+
363+ end subroutine tds_solve_cuda
364+
365+ subroutine tds_solve_dist (self , du , u , dirps , tdsops , blocks , threads )
366+ implicit none
367+
368+ class(cuda_backend_t) :: self
369+ class(field_t), intent (inout ) :: du
370+ class(field_t), intent (in ) :: u
371+ type (dirps_t), intent (in ) :: dirps
372+ class(tdsops_t), intent (in ) :: tdsops
373+ type (dim3), intent (in ) :: blocks, threads
374+
375+ real (dp), device, pointer , dimension (:, :, :) :: du_dev, u_dev
376+
377+ type (cuda_tdsops_t), pointer :: tdsops_dev
378+
379+ select type (du); type is (cuda_field_t); du_dev = > du% data_d; end select
380+ select type (u); type is (cuda_field_t); u_dev = > u% data_d; end select
381+
382+ select type (tdsops)
383+ type is (cuda_tdsops_t); tdsops_dev = > tdsops
384+ end select
385+
386+ call copy_into_buffers(self% u_send_s_dev, self% u_send_e_dev, u_dev, &
387+ tdsops_dev% n)
388+
389+ call sendrecv_fields(self% u_recv_s_dev, self% u_recv_e_dev, &
390+ self% u_send_s_dev, self% u_send_e_dev, &
391+ SZ* 4 * blocks% x, dirps% nproc, &
392+ dirps% pprev, dirps% pnext)
393+
394+ ! call exec_dist
395+ call exec_dist_tds_compact( &
396+ du_dev, u_dev, &
397+ self% u_recv_s_dev, self% u_recv_e_dev, &
398+ self% du_send_s_dev, self% du_send_e_dev, &
399+ self% du_recv_s_dev, self% du_recv_e_dev, &
400+ tdsops_dev, dirps% nproc, dirps% pprev, dirps% pnext, &
401+ blocks, threads &
402+ )
403+
404+ end subroutine tds_solve_dist
405+
346406 subroutine trans_x2y_cuda (self , u_y , v_y , w_y , u , v , w )
347407 implicit none
348408
0 commit comments