Skip to content

Commit 29677ec

Browse files
authored
Merge pull request xcompact3d#25 from semi-h/feature
Add an interface for solving a single tridiagonal system.
2 parents a97f7c4 + 68f95ff commit 29677ec

File tree

5 files changed

+103
-3
lines changed

5 files changed

+103
-3
lines changed

src/backend.f90

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ module m_base_backend
2727
procedure(transeq_ders), deferred :: transeq_x
2828
procedure(transeq_ders), deferred :: transeq_y
2929
procedure(transeq_ders), deferred :: transeq_z
30+
procedure(tds_solve), deferred :: tds_solve
3031
procedure(transposer), deferred :: trans_x2y
3132
procedure(transposer), deferred :: trans_x2z
3233
procedure(sum9into3), deferred :: sum_yzintox
@@ -54,6 +55,27 @@ subroutine transeq_ders(self, du, dv, dw, u, v, w, dirps)
5455
end subroutine transeq_ders
5556
end interface
5657

58+
abstract interface
59+
subroutine tds_solve(self, du, u, dirps, tdsops)
60+
!! tds_solve solves a single tridiagonal system: u is the input
61+
!! field and du receives the result. dirps carries the settings
62+
!! of the direction being solved and tdsops the operator data to
63+
!! use. Backend implementations are responsible for directing
64+
!! calls to tds_solve into the correct algorithm.
65+
import :: base_backend_t
66+
import :: field_t
67+
import :: dirps_t
68+
import :: tdsops_t
69+
implicit none
70+
71+
class(base_backend_t) :: self
72+
class(field_t), intent(inout) :: du
73+
class(field_t), intent(in) :: u
74+
type(dirps_t), intent(in) :: dirps
75+
class(tdsops_t), intent(in) :: tdsops
76+
end subroutine tds_solve
77+
end interface
78+
5779
abstract interface
5880
subroutine transposer(self, u_, v_, w_, u, v, w)
5981
!! transposer subroutines are straightforward, they rearrange

src/cuda/backend.f90

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ module m_cuda_backend
88

99
use m_cuda_allocator, only: cuda_allocator_t, cuda_field_t
1010
use m_cuda_common, only: SZ
11-
use m_cuda_exec_dist, only: exec_dist_transeq_3fused
12-
use m_cuda_sendrecv, only: sendrecv_3fields
11+
use m_cuda_exec_dist, only: exec_dist_transeq_3fused, exec_dist_tds_compact
12+
use m_cuda_sendrecv, only: sendrecv_fields, sendrecv_3fields
1313
use m_cuda_tdsops, only: cuda_tdsops_t
1414
use m_cuda_kernels_dist, only: transeq_3fused_dist, transeq_3fused_subs
1515

@@ -31,13 +31,15 @@ module m_cuda_backend
3131
procedure :: transeq_x => transeq_x_cuda
3232
procedure :: transeq_y => transeq_y_cuda
3333
procedure :: transeq_z => transeq_z_cuda
34+
procedure :: tds_solve => tds_solve_cuda
3435
procedure :: trans_x2y => trans_x2y_cuda
3536
procedure :: trans_x2z => trans_x2z_cuda
3637
procedure :: sum_yzintox => sum_yzintox_cuda
3738
procedure :: set_fields => set_fields_cuda
3839
procedure :: get_fields => get_fields_cuda
3940
procedure :: transeq_cuda_dist
4041
procedure :: transeq_cuda_thom
42+
procedure :: tds_solve_dist
4143
end type cuda_backend_t
4244

4345
interface cuda_backend_t
@@ -343,6 +345,64 @@ subroutine transeq_cuda_thom(self, du, dv, dw, u, v, w, dirps)
343345

344346
end subroutine transeq_cuda_thom
345347

348+
subroutine tds_solve_cuda(self, du, u, dirps, tdsops)
!! CUDA implementation of the deferred tds_solve binding. It builds
!! the blocks/threads dimensions from dirps%n_blocks and SZ and then
!! forwards every argument to tds_solve_dist, which does the work.
349+
implicit none
350+
351+
class(cuda_backend_t) :: self
352+
class(field_t), intent(inout) :: du
353+
class(field_t), intent(in) :: u
354+
type(dirps_t), intent(in) :: dirps
355+
class(tdsops_t), intent(in) :: tdsops
356+
357+
type(dim3) :: blocks, threads
358+
359+
! One-dimensional configuration: dirps%n_blocks blocks of SZ threads.
! NOTE(review): presumably the kernel launch configuration used inside
! exec_dist_tds_compact - confirm against m_cuda_exec_dist.
blocks = dim3(dirps%n_blocks, 1, 1); threads = dim3(SZ, 1, 1)
360+
361+
! Called as a plain module procedure, so self is passed explicitly.
call tds_solve_dist(self, du, u, dirps, tdsops, blocks, threads)
362+
363+
end subroutine tds_solve_cuda
364+
365+
subroutine tds_solve_dist(self, du, u, dirps, tdsops, blocks, threads)
!! Solves a single tridiagonal system through exec_dist_tds_compact.
!! Steps: (1) obtain the device arrays behind du and u and the CUDA
!! view of tdsops, (2) copy data from u into the backend send buffers
!! and exchange them with the pprev/pnext neighbours, (3) call
!! exec_dist_tds_compact with the received buffers.
!! blocks and threads are forwarded unchanged to exec_dist_tds_compact.
366+
implicit none
367+
368+
class(cuda_backend_t) :: self
369+
class(field_t), intent(inout) :: du
370+
class(field_t), intent(in) :: u
371+
type(dirps_t), intent(in) :: dirps
372+
class(tdsops_t), intent(in) :: tdsops
373+
type(dim3), intent(in) :: blocks, threads
374+
375+
real(dp), device, pointer, dimension(:, :, :) :: du_dev, u_dev
376+
377+
type(cuda_tdsops_t), pointer :: tdsops_dev
378+
379+
! Resolve the polymorphic fields to their device data.
! NOTE(review): there is no fallback branch, so if du or u is not a
! cuda_field_t the corresponding pointer is never assigned here.
select type(du); type is (cuda_field_t); du_dev => du%data_d; end select
380+
select type(u); type is (cuda_field_t); u_dev => u%data_d; end select
381+
382+
! Same pattern for the operator data: only cuda_tdsops_t is handled.
select type (tdsops)
383+
type is (cuda_tdsops_t); tdsops_dev => tdsops
384+
end select
385+
386+
! Fill the start/end send buffers from u.
call copy_into_buffers(self%u_send_s_dev, self%u_send_e_dev, u_dev, &
387+
tdsops_dev%n)
388+
389+
! Exchange the send buffers with the previous/next ranks.
! NOTE(review): SZ*4*blocks%x is the size passed for the exchange; the
! factor 4 is presumably the per-line halo depth - TODO confirm.
call sendrecv_fields(self%u_recv_s_dev, self%u_recv_e_dev, &
390+
self%u_send_s_dev, self%u_send_e_dev, &
391+
SZ*4*blocks%x, dirps%nproc, &
392+
dirps%pprev, dirps%pnext)
393+
394+
! Run the distributed solve; the du send/recv buffers are handed over
! for use inside exec_dist_tds_compact.
395+
call exec_dist_tds_compact( &
396+
du_dev, u_dev, &
397+
self%u_recv_s_dev, self%u_recv_e_dev, &
398+
self%du_send_s_dev, self%du_send_e_dev, &
399+
self%du_recv_s_dev, self%du_recv_e_dev, &
400+
tdsops_dev, dirps%nproc, dirps%pprev, dirps%pnext, &
401+
blocks, threads &
402+
)
403+
404+
end subroutine tds_solve_dist
405+
346406
subroutine trans_x2y_cuda(self, u_y, v_y, w_y, u, v, w)
347407
implicit none
348408

src/omp/backend.f90

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ module m_omp_backend
2323
procedure :: transeq_x => transeq_x_omp
2424
procedure :: transeq_y => transeq_y_omp
2525
procedure :: transeq_z => transeq_z_omp
26+
procedure :: tds_solve => tds_solve_omp
2627
procedure :: trans_x2y => trans_x2y_omp
2728
procedure :: trans_x2z => trans_x2z_omp
2829
procedure :: sum_yzintox => sum_yzintox_omp
@@ -138,6 +139,19 @@ subroutine transeq_z_omp(self, du, dv, dw, u, v, w, dirps)
138139

139140
end subroutine transeq_z_omp
140141

142+
subroutine tds_solve_omp(self, du, u, dirps, tdsops)
!! OpenMP implementation of the deferred tds_solve binding.
!! Placeholder: the body contains no executable statement, so the
!! call returns immediately and du is left unchanged.
143+
implicit none
144+
145+
class(omp_backend_t) :: self
146+
class(field_t), intent(inout) :: du
147+
class(field_t), intent(in) :: u
148+
type(dirps_t), intent(in) :: dirps
149+
class(tdsops_t), intent(in) :: tdsops
150+
151+
! Not implemented yet; the intended call is kept below for reference.
! NOTE(review): if enabled as a type-bound call, self should not also be
! passed as the first actual argument.
!call self%tds_solve_dist(self, du, u, dirps, tdsops)
152+
153+
end subroutine tds_solve_omp
154+
141155
subroutine trans_x2y_omp(self, u_, v_, w_, u, v, w)
142156
implicit none
143157

src/tdsops.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ module m_tdsops
4141
der2nd, der2nd_sym, &
4242
stagder_v2p, stagder_p2v, &
4343
interpl_v2p, interpl_p2v
44-
integer :: nrank, nproc, pnext, pprev, n
44+
integer :: nrank, nproc, pnext, pprev, n, n_blocks
4545
end type dirps_t
4646

4747
contains

src/xcompact.f90

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ program xcompact
9494
ydirps%n = globs%ny_loc
9595
zdirps%n = globs%nz_loc
9696

97+
xdirps%n_blocks = globs%n_groups_x
98+
ydirps%n_blocks = globs%n_groups_y
99+
zdirps%n_blocks = globs%n_groups_z
100+
97101
#ifdef CUDA
98102
cuda_allocator = cuda_allocator_t([SZ, globs%nx_loc, globs%n_groups_x])
99103
allocator => cuda_allocator

0 commit comments

Comments
 (0)