Skip to content

Commit 410dac4

Browse files
authored
Merge pull request xcompact3d#16 from semi-h/feature
Add new subroutines to execute the generic and fused distributed solvers.
2 parents 4aab26c + 3de4342 commit 410dac4

File tree

5 files changed

+251
-165
lines changed

5 files changed

+251
-165
lines changed

src/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ set(SRC
1313
set(CUDASRC
1414
cuda/common.f90
1515
cuda/cuda_allocator.f90
16+
cuda/exec_dist.f90
1617
cuda/kernels_dist.f90
18+
cuda/sendrecv.f90
1719
cuda/tdsops.f90
1820
)
1921

src/cuda/exec_dist.f90

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
module m_cuda_exec_dist
  !! Driver routines that execute the distributed tridiagonal solvers on
  !! CUDA devices. Each driver follows the same three-phase pattern:
  !!   1) a device kernel performs the local forward/backward passes and
  !!      writes the reduced 2x2 interface system into the *_send buffers,
  !!   2) an MPI halo exchange swaps the interface data with the
  !!      neighbouring ranks (pprev/pnext),
  !!   3) a second device kernel substitutes the received interface
  !!      solution back into the local result.
  use cudafor
  use mpi

  use m_common, only: dp
  use m_cuda_common, only: SZ
  use m_cuda_kernels_dist, only: der_univ_dist, der_univ_subs, &
                                 transeq_3fused_dist, transeq_3fused_subs
  use m_cuda_sendrecv, only: sendrecv_fields, sendrecv_3fields
  use m_cuda_tdsops, only: cuda_tdsops_t

  implicit none

contains

  subroutine exec_dist_tds_compact( &
    du, u, u_recv_s, u_recv_e, du_send_s, du_send_e, du_recv_s, du_recv_e, &
    tdsops, nproc, pprev, pnext, blocks, threads &
    )
    !! Apply a single compact tridiagonal operator distributed over
    !! `nproc` ranks: du = tdsops(u).

    ! du = d(u)
    real(dp), device, dimension(:, :, :), intent(out) :: du
    !> Local field and the halo slices already received from neighbours.
    real(dp), device, dimension(:, :, :), intent(in) :: u, u_recv_s, u_recv_e

    ! The ones below are intent(out) just so that we can write data in them,
    ! not because we actually need the data they store later where this
    ! subroutine is called. We absolutely don't care the data they pass back.
    real(dp), device, dimension(:, :, :), intent(out) :: &
      du_send_s, du_send_e, du_recv_s, du_recv_e

    type(cuda_tdsops_t), intent(in) :: tdsops
    integer, intent(in) :: nproc, pprev, pnext
    type(dim3), intent(in) :: blocks, threads

    ! Number of reals exchanged per direction: one SZ-wide interface line
    ! per thread block.
    integer :: n_data

    n_data = SZ*1*blocks%x

    ! Phase 1: local forward/backward passes; emits the reduced interface
    ! system into du_send_s/du_send_e.
    call der_univ_dist<<<blocks, threads>>>( &
      du, du_send_s, du_send_e, u, u_recv_s, u_recv_e, &
      tdsops%coeffs_s_dev, tdsops%coeffs_e_dev, tdsops%coeffs_dev, &
      tdsops%n, tdsops%dist_fw_dev, tdsops%dist_bw_dev, tdsops%dist_af_dev &
      )

    ! Phase 2: halo exchange for the 2x2 interface systems.
    call sendrecv_fields(du_recv_s, du_recv_e, du_send_s, du_send_e, &
                         n_data, nproc, pprev, pnext)

    ! Phase 3: substitute the interface solution back into du.
    call der_univ_subs<<<blocks, threads>>>( &
      du, du_recv_s, du_recv_e, &
      tdsops%n, tdsops%dist_sa_dev, tdsops%dist_sc_dev &
      )

  end subroutine exec_dist_tds_compact

  subroutine exec_dist_transeq_3fused( &
    r_u, u, u_recv_s, u_recv_e, v, v_recv_s, v_recv_e, &
    du, dud, d2u, &
    du_send_s, du_send_e, du_recv_s, du_recv_e, &
    dud_send_s, dud_send_e, dud_recv_s, dud_recv_e, &
    d2u_send_s, d2u_send_e, d2u_recv_s, d2u_recv_e, &
    der1st, der2nd, nu, nproc, pprev, pnext, blocks, threads &
    )
    !! Fused evaluation of one transport-equation RHS component,
    !! computing d1(u), d1(u*v) and d2(u) in a single distributed sweep.

    ! r_u = -1/2*(v*d1(u) + d1(u*v)) + nu*d2(u)
    real(dp), device, dimension(:, :, :), intent(out) :: r_u
    real(dp), device, dimension(:, :, :), intent(in) :: u, u_recv_s, u_recv_e
    real(dp), device, dimension(:, :, :), intent(in) :: v, v_recv_s, v_recv_e

    ! The ones below are intent(out) just so that we can write data in them,
    ! not because we actually need the data they store later where this
    ! subroutine is called. We absolutely don't care the data they pass back.
    real(dp), device, dimension(:, :, :), intent(out) :: du, dud, d2u
    real(dp), device, dimension(:, :, :), intent(out) :: &
      du_send_s, du_send_e, du_recv_s, du_recv_e, &
      dud_send_s, dud_send_e, dud_recv_s, dud_recv_e, &
      d2u_send_s, d2u_send_e, d2u_recv_s, d2u_recv_e

    !> Tridiagonal operators for the first and second derivatives.
    type(cuda_tdsops_t), intent(in) :: der1st, der2nd
    !> Kinematic viscosity.
    real(dp), intent(in) :: nu
    integer, intent(in) :: nproc, pprev, pnext
    type(dim3), intent(in) :: blocks, threads

    ! Number of reals exchanged per field and direction.
    integer :: n_data

    n_data = SZ*1*blocks%x

    ! Phase 1: fused local passes for all three derivative fields; the
    ! reduced interface systems go into the *_send buffers.
    call transeq_3fused_dist<<<blocks, threads>>>( &
      du, dud, d2u, &
      du_send_s, du_send_e, &
      dud_send_s, dud_send_e, &
      d2u_send_s, d2u_send_e, &
      u, u_recv_s, u_recv_e, &
      v, v_recv_s, v_recv_e, der1st%n, &
      der1st%coeffs_s_dev, der1st%coeffs_e_dev, der1st%coeffs_dev, &
      der1st%dist_fw_dev, der1st%dist_bw_dev, der1st%dist_af_dev, &
      der2nd%coeffs_s_dev, der2nd%coeffs_e_dev, der2nd%coeffs_dev, &
      der2nd%dist_fw_dev, der2nd%dist_bw_dev, der2nd%dist_af_dev &
      )

    ! Phase 2: halo exchange for the 2x2 interface systems of all three
    ! fields in one batch of nonblocking MPI calls.
    call sendrecv_3fields( &
      du_recv_s, du_recv_e, dud_recv_s, dud_recv_e, &
      d2u_recv_s, d2u_recv_e, &
      du_send_s, du_send_e, dud_send_s, dud_send_e, &
      d2u_send_s, d2u_send_e, &
      n_data, nproc, pprev, pnext &
      )

    ! Phase 3: substitute and combine the three fields into r_u.
    call transeq_3fused_subs<<<blocks, threads>>>( &
      r_u, v, du, dud, d2u, &
      du_recv_s, du_recv_e, &
      dud_recv_s, dud_recv_e, &
      d2u_recv_s, d2u_recv_e, &
      der1st%dist_sa_dev, der1st%dist_sc_dev, &
      der2nd%dist_sa_dev, der2nd%dist_sc_dev, &
      der1st%n, nu &
      )

  end subroutine exec_dist_transeq_3fused

end module m_cuda_exec_dist

src/cuda/sendrecv.f90

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
module m_cuda_sendrecv
  !! Nonblocking MPI halo exchanges for device-resident fields
  !! (CUDA-aware MPI: buffers carry the `device` attribute).
  use cudafor
  use mpi

  use m_common, only: dp

  implicit none

contains

  subroutine sendrecv_fields(f_recv_s, f_recv_e, f_send_s, f_send_e, &
                             n_data, nproc, prev, next)
    !! Exchange one field's start/end halo slices with the neighbouring
    !! ranks. With a single rank the domain is periodic onto itself, so
    !! the buffers are swapped locally without any MPI traffic.

    real(dp), device, dimension(:, :, :), intent(out) :: f_recv_s, f_recv_e
    real(dp), device, dimension(:, :, :), intent(in) :: f_send_s, f_send_e
    !> n_data: number of reals per message; prev/next: neighbour ranks.
    integer, intent(in) :: n_data, nproc, prev, next

    ! `tag` is a parameter: initialising a local in its declaration would
    ! give it the implicit `save` attribute, which is a latent trap.
    integer, parameter :: tag = 1234
    ! NOTE(review): per-call statuses in err(:) and the Waitall ierr are
    ! not inspected; errors surface only via the default MPI error handler.
    integer :: req(4), err(4), ierr

    if (nproc == 1) then
      ! Periodic wrap-around on a single rank: start halo comes from the
      ! end of the domain and vice versa.
      f_recv_s = f_send_e
      f_recv_e = f_send_s
    else
      call MPI_Isend(f_send_s, n_data, MPI_DOUBLE_PRECISION, &
                     prev, tag, MPI_COMM_WORLD, req(1), err(1))
      call MPI_Irecv(f_recv_e, n_data, MPI_DOUBLE_PRECISION, &
                     next, tag, MPI_COMM_WORLD, req(2), err(2))
      call MPI_Isend(f_send_e, n_data, MPI_DOUBLE_PRECISION, &
                     next, tag, MPI_COMM_WORLD, req(3), err(3))
      call MPI_Irecv(f_recv_s, n_data, MPI_DOUBLE_PRECISION, &
                     prev, tag, MPI_COMM_WORLD, req(4), err(4))

      call MPI_Waitall(4, req, MPI_STATUSES_IGNORE, ierr)
    end if

  end subroutine sendrecv_fields

  subroutine sendrecv_3fields( &
    f1_recv_s, f1_recv_e, f2_recv_s, f2_recv_e, f3_recv_s, f3_recv_e, &
    f1_send_s, f1_send_e, f2_send_s, f2_send_e, f3_send_s, f3_send_e, &
    n_data, nproc, prev, next &
    )
    !! Exchange the halo slices of three fields in a single batch of
    !! nonblocking calls (one Waitall). All messages share one tag; this
    !! is safe because MPI's non-overtaking rule preserves the posting
    !! order (f1, f2, f3) between a given pair of ranks.

    real(dp), device, dimension(:, :, :), intent(out) :: &
      f1_recv_s, f1_recv_e, f2_recv_s, f2_recv_e, f3_recv_s, f3_recv_e
    real(dp), device, dimension(:, :, :), intent(in) :: &
      f1_send_s, f1_send_e, f2_send_s, f2_send_e, f3_send_s, f3_send_e
    integer, intent(in) :: n_data, nproc, prev, next

    ! Parameter rather than an initialised local (implicit-save trap).
    integer, parameter :: tag = 1234
    ! NOTE(review): err(:)/ierr statuses are not inspected (see above).
    integer :: req(12), err(12), ierr

    if (nproc == 1) then
      ! Single-rank periodic wrap-around for all three fields.
      f1_recv_s = f1_send_e
      f1_recv_e = f1_send_s
      f2_recv_s = f2_send_e
      f2_recv_e = f2_send_s
      f3_recv_s = f3_send_e
      f3_recv_e = f3_send_s
    else
      call MPI_Isend(f1_send_s, n_data, MPI_DOUBLE_PRECISION, &
                     prev, tag, MPI_COMM_WORLD, req(1), err(1))
      call MPI_Irecv(f1_recv_e, n_data, MPI_DOUBLE_PRECISION, &
                     next, tag, MPI_COMM_WORLD, req(2), err(2))
      call MPI_Isend(f1_send_e, n_data, MPI_DOUBLE_PRECISION, &
                     next, tag, MPI_COMM_WORLD, req(3), err(3))
      call MPI_Irecv(f1_recv_s, n_data, MPI_DOUBLE_PRECISION, &
                     prev, tag, MPI_COMM_WORLD, req(4), err(4))

      call MPI_Isend(f2_send_s, n_data, MPI_DOUBLE_PRECISION, &
                     prev, tag, MPI_COMM_WORLD, req(5), err(5))
      call MPI_Irecv(f2_recv_e, n_data, MPI_DOUBLE_PRECISION, &
                     next, tag, MPI_COMM_WORLD, req(6), err(6))
      call MPI_Isend(f2_send_e, n_data, MPI_DOUBLE_PRECISION, &
                     next, tag, MPI_COMM_WORLD, req(7), err(7))
      call MPI_Irecv(f2_recv_s, n_data, MPI_DOUBLE_PRECISION, &
                     prev, tag, MPI_COMM_WORLD, req(8), err(8))

      call MPI_Isend(f3_send_s, n_data, MPI_DOUBLE_PRECISION, &
                     prev, tag, MPI_COMM_WORLD, req(9), err(9))
      call MPI_Irecv(f3_recv_e, n_data, MPI_DOUBLE_PRECISION, &
                     next, tag, MPI_COMM_WORLD, req(10), err(10))
      call MPI_Isend(f3_send_e, n_data, MPI_DOUBLE_PRECISION, &
                     next, tag, MPI_COMM_WORLD, req(11), err(11))
      call MPI_Irecv(f3_recv_s, n_data, MPI_DOUBLE_PRECISION, &
                     prev, tag, MPI_COMM_WORLD, req(12), err(12))

      call MPI_Waitall(12, req, MPI_STATUSES_IGNORE, ierr)
    end if

  end subroutine sendrecv_3fields

end module m_cuda_sendrecv

tests/cuda/test_cuda_transeq.f90

Lines changed: 18 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ program test_cuda_tridiag
55

66
use m_common, only: dp, pi
77
use m_cuda_common, only: SZ
8-
use m_cuda_kernels_dist, only: transeq_3fused_dist, transeq_3fused_subs
8+
use m_cuda_exec_dist, only: exec_dist_transeq_3fused
9+
use m_cuda_sendrecv, only: sendrecv_fields, sendrecv_3fields
910
use m_cuda_tdsops, only: cuda_tdsops_t
1011

1112
implicit none
@@ -119,116 +120,25 @@ program test_cuda_tridiag
119120
v_send_s_dev(:, :, :) = v_dev(:, 1:4, :)
120121
v_send_e_dev(:, :, :) = v_dev(:, n - n_halo + 1:n, :)
121122

122-
! halo exchange
123-
if (nproc == 1) then
124-
u_recv_s_dev = u_send_e_dev
125-
u_recv_e_dev = u_send_s_dev
126-
v_recv_s_dev = v_send_e_dev
127-
v_recv_e_dev = v_send_s_dev
128-
else
129-
! MPI send/recv for multi-rank simulations
130-
call MPI_Isend(u_send_s_dev, SZ*n_halo*n_block, &
131-
MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
132-
mpireq(1), srerr(1))
133-
call MPI_Irecv(u_recv_e_dev, SZ*n_halo*n_block, &
134-
MPI_DOUBLE_PRECISION, pnext, tag1, MPI_COMM_WORLD, &
135-
mpireq(2), srerr(2))
136-
call MPI_Isend(u_send_e_dev, SZ*n_halo*n_block, &
137-
MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
138-
mpireq(3), srerr(3))
139-
call MPI_Irecv(u_recv_s_dev, SZ*n_halo*n_block, &
140-
MPI_DOUBLE_PRECISION, pprev, tag2, MPI_COMM_WORLD, &
141-
mpireq(4), srerr(4))
142-
143-
call MPI_Isend(v_send_s_dev, SZ*n_halo*n_block, &
144-
MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
145-
mpireq(5), srerr(5))
146-
call MPI_Irecv(v_recv_e_dev, SZ*n_halo*n_block, &
147-
MPI_DOUBLE_PRECISION, pnext, tag1, MPI_COMM_WORLD, &
148-
mpireq(6), srerr(6))
149-
call MPI_Isend(v_send_e_dev, SZ*n_halo*n_block, &
150-
MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
151-
mpireq(7), srerr(7))
152-
call MPI_Irecv(v_recv_s_dev, SZ*n_halo*n_block, &
153-
MPI_DOUBLE_PRECISION, pprev, tag2, MPI_COMM_WORLD, &
154-
mpireq(8), srerr(8))
155-
156-
call MPI_Waitall(8, mpireq, MPI_STATUSES_IGNORE, ierr)
157-
end if
158123

159-
call transeq_3fused_dist<<<blocks, threads>>>( &
160-
du_dev, dud_dev, d2u_dev, &
161-
du_send_s_dev, du_send_e_dev, &
162-
dud_send_s_dev, dud_send_e_dev, &
163-
d2u_send_s_dev, d2u_send_e_dev, &
164-
u_dev, u_recv_s_dev, u_recv_e_dev, &
165-
v_dev, v_recv_s_dev, v_recv_e_dev, n, &
166-
der1st%coeffs_s_dev, der1st%coeffs_e_dev, der1st%coeffs_dev, &
167-
der1st%dist_fw_dev, der1st%dist_bw_dev, der1st%dist_af_dev, &
168-
der2nd%coeffs_s_dev, der2nd%coeffs_e_dev, der2nd%coeffs_dev, &
169-
der2nd%dist_fw_dev, der2nd%dist_bw_dev, der2nd%dist_af_dev &
170-
)
124+
! halo exchange
125+
call sendrecv_fields(u_recv_s_dev, u_recv_e_dev, &
126+
u_send_s_dev, u_send_e_dev, &
127+
SZ*4*n_block, nproc, pprev, pnext)
171128

172-
! halo exchange for 2x2 systems
173-
if (nproc == 1) then
174-
du_recv_s_dev = du_send_e_dev
175-
du_recv_e_dev = du_send_s_dev
176-
dud_recv_s_dev = dud_send_e_dev
177-
dud_recv_e_dev = dud_send_s_dev
178-
d2u_recv_s_dev = d2u_send_e_dev
179-
d2u_recv_e_dev = d2u_send_s_dev
180-
else
181-
! MPI send/recv for multi-rank simulations
182-
call MPI_Isend(du_send_s_dev, SZ*n_block, &
183-
MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
184-
mpireq(1), srerr(1))
185-
call MPI_Irecv(du_recv_e_dev, SZ*n_block, &
186-
MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
187-
mpireq(2), srerr(2))
188-
call MPI_Isend(du_send_e_dev, SZ*n_block, &
189-
MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
190-
mpireq(3), srerr(3))
191-
call MPI_Irecv(du_recv_s_dev, SZ*n_block, &
192-
MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
193-
mpireq(4), srerr(4))
194-
195-
call MPI_Isend(dud_send_s_dev, SZ*n_block, &
196-
MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
197-
mpireq(5), srerr(5))
198-
call MPI_Irecv(dud_recv_e_dev, SZ*n_block, &
199-
MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
200-
mpireq(6), srerr(6))
201-
call MPI_Isend(dud_send_e_dev, SZ*n_block, &
202-
MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
203-
mpireq(7), srerr(7))
204-
call MPI_Irecv(dud_recv_s_dev, SZ*n_block, &
205-
MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
206-
mpireq(8), srerr(8))
207-
208-
call MPI_Isend(d2u_send_s_dev, SZ*n_block, &
209-
MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
210-
mpireq(9), srerr(9))
211-
call MPI_Irecv(d2u_recv_e_dev, SZ*n_block, &
212-
MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
213-
mpireq(10), srerr(10))
214-
call MPI_Isend(d2u_send_e_dev, SZ*n_block, &
215-
MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
216-
mpireq(11), srerr(11))
217-
call MPI_Irecv(d2u_recv_s_dev, SZ*n_block, &
218-
MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
219-
mpireq(12), srerr(12))
220-
221-
call MPI_Waitall(12, mpireq, MPI_STATUSES_IGNORE, ierr)
222-
end if
129+
call sendrecv_fields(v_recv_s_dev, v_recv_e_dev, &
130+
v_send_s_dev, v_send_e_dev, &
131+
SZ*4*n_block, nproc, pprev, pnext)
223132

224-
call transeq_3fused_subs<<<blocks, threads>>>( &
225-
r_u_dev, v_dev, du_dev, dud_dev, d2u_dev, &
226-
du_recv_s_dev, du_recv_e_dev, &
227-
dud_recv_s_dev, dud_recv_e_dev, &
228-
d2u_recv_s_dev, d2u_recv_e_dev, &
229-
der1st%dist_sa_dev, der1st%dist_sc_dev, &
230-
der2nd%dist_sa_dev, der2nd%dist_sc_dev, &
231-
n, nu &
133+
call exec_dist_transeq_3fused( &
134+
r_u_dev, &
135+
u_dev, u_recv_s_dev, u_recv_e_dev, &
136+
v_dev, v_recv_s_dev, v_recv_e_dev, &
137+
du_dev, dud_dev, d2u_dev, &
138+
du_send_s_dev, du_send_e_dev, du_recv_s_dev, du_recv_e_dev, &
139+
dud_send_s_dev, dud_send_e_dev, dud_recv_s_dev, dud_recv_e_dev, &
140+
d2u_send_s_dev, d2u_send_e_dev, d2u_recv_s_dev, d2u_recv_e_dev, &
141+
der1st, der2nd, nu, nproc, pprev, pnext, blocks, threads &
232142
)
233143
end do
234144

0 commit comments

Comments
 (0)