Skip to content

Commit 3de4342

Browse files
committed
feat(tests/cuda): Use the transeq exec_dist subroutine in the tests.
1 parent 5c27535 commit 3de4342

File tree

1 file changed

+18
-108
lines changed

1 file changed

+18
-108
lines changed

tests/cuda/test_cuda_transeq.f90

Lines changed: 18 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ program test_cuda_tridiag
55

66
use m_common, only: dp, pi
77
use m_cuda_common, only: SZ
8-
use m_cuda_kernels_dist, only: transeq_3fused_dist, transeq_3fused_subs
8+
use m_cuda_exec_dist, only: exec_dist_transeq_3fused
9+
use m_cuda_sendrecv, only: sendrecv_fields, sendrecv_3fields
910
use m_cuda_tdsops, only: cuda_tdsops_t
1011

1112
implicit none
@@ -119,116 +120,25 @@ program test_cuda_tridiag
119120
v_send_s_dev(:, :, :) = v_dev(:, 1:4, :)
120121
v_send_e_dev(:, :, :) = v_dev(:, n - n_halo + 1:n, :)
121122

122-
! halo exchange
123-
if (nproc == 1) then
124-
u_recv_s_dev = u_send_e_dev
125-
u_recv_e_dev = u_send_s_dev
126-
v_recv_s_dev = v_send_e_dev
127-
v_recv_e_dev = v_send_s_dev
128-
else
129-
! MPI send/recv for multi-rank simulations
130-
call MPI_Isend(u_send_s_dev, SZ*n_halo*n_block, &
131-
MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
132-
mpireq(1), srerr(1))
133-
call MPI_Irecv(u_recv_e_dev, SZ*n_halo*n_block, &
134-
MPI_DOUBLE_PRECISION, pnext, tag1, MPI_COMM_WORLD, &
135-
mpireq(2), srerr(2))
136-
call MPI_Isend(u_send_e_dev, SZ*n_halo*n_block, &
137-
MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
138-
mpireq(3), srerr(3))
139-
call MPI_Irecv(u_recv_s_dev, SZ*n_halo*n_block, &
140-
MPI_DOUBLE_PRECISION, pprev, tag2, MPI_COMM_WORLD, &
141-
mpireq(4), srerr(4))
142-
143-
call MPI_Isend(v_send_s_dev, SZ*n_halo*n_block, &
144-
MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
145-
mpireq(5), srerr(5))
146-
call MPI_Irecv(v_recv_e_dev, SZ*n_halo*n_block, &
147-
MPI_DOUBLE_PRECISION, pnext, tag1, MPI_COMM_WORLD, &
148-
mpireq(6), srerr(6))
149-
call MPI_Isend(v_send_e_dev, SZ*n_halo*n_block, &
150-
MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
151-
mpireq(7), srerr(7))
152-
call MPI_Irecv(v_recv_s_dev, SZ*n_halo*n_block, &
153-
MPI_DOUBLE_PRECISION, pprev, tag2, MPI_COMM_WORLD, &
154-
mpireq(8), srerr(8))
155-
156-
call MPI_Waitall(8, mpireq, MPI_STATUSES_IGNORE, ierr)
157-
end if
158123

159-
call transeq_3fused_dist<<<blocks, threads>>>( &
160-
du_dev, dud_dev, d2u_dev, &
161-
du_send_s_dev, du_send_e_dev, &
162-
dud_send_s_dev, dud_send_e_dev, &
163-
d2u_send_s_dev, d2u_send_e_dev, &
164-
u_dev, u_recv_s_dev, u_recv_e_dev, &
165-
v_dev, v_recv_s_dev, v_recv_e_dev, n, &
166-
der1st%coeffs_s_dev, der1st%coeffs_e_dev, der1st%coeffs_dev, &
167-
der1st%dist_fw_dev, der1st%dist_bw_dev, der1st%dist_af_dev, &
168-
der2nd%coeffs_s_dev, der2nd%coeffs_e_dev, der2nd%coeffs_dev, &
169-
der2nd%dist_fw_dev, der2nd%dist_bw_dev, der2nd%dist_af_dev &
170-
)
124+
! halo exchange
125+
call sendrecv_fields(u_recv_s_dev, u_recv_e_dev, &
126+
u_send_s_dev, u_send_e_dev, &
127+
SZ*4*n_block, nproc, pprev, pnext)
171128

172-
! halo exchange for 2x2 systems
173-
if (nproc == 1) then
174-
du_recv_s_dev = du_send_e_dev
175-
du_recv_e_dev = du_send_s_dev
176-
dud_recv_s_dev = dud_send_e_dev
177-
dud_recv_e_dev = dud_send_s_dev
178-
d2u_recv_s_dev = d2u_send_e_dev
179-
d2u_recv_e_dev = d2u_send_s_dev
180-
else
181-
! MPI send/recv for multi-rank simulations
182-
call MPI_Isend(du_send_s_dev, SZ*n_block, &
183-
MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
184-
mpireq(1), srerr(1))
185-
call MPI_Irecv(du_recv_e_dev, SZ*n_block, &
186-
MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
187-
mpireq(2), srerr(2))
188-
call MPI_Isend(du_send_e_dev, SZ*n_block, &
189-
MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
190-
mpireq(3), srerr(3))
191-
call MPI_Irecv(du_recv_s_dev, SZ*n_block, &
192-
MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
193-
mpireq(4), srerr(4))
194-
195-
call MPI_Isend(dud_send_s_dev, SZ*n_block, &
196-
MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
197-
mpireq(5), srerr(5))
198-
call MPI_Irecv(dud_recv_e_dev, SZ*n_block, &
199-
MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
200-
mpireq(6), srerr(6))
201-
call MPI_Isend(dud_send_e_dev, SZ*n_block, &
202-
MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
203-
mpireq(7), srerr(7))
204-
call MPI_Irecv(dud_recv_s_dev, SZ*n_block, &
205-
MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
206-
mpireq(8), srerr(8))
207-
208-
call MPI_Isend(d2u_send_s_dev, SZ*n_block, &
209-
MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
210-
mpireq(9), srerr(9))
211-
call MPI_Irecv(d2u_recv_e_dev, SZ*n_block, &
212-
MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
213-
mpireq(10), srerr(10))
214-
call MPI_Isend(d2u_send_e_dev, SZ*n_block, &
215-
MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
216-
mpireq(11), srerr(11))
217-
call MPI_Irecv(d2u_recv_s_dev, SZ*n_block, &
218-
MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
219-
mpireq(12), srerr(12))
220-
221-
call MPI_Waitall(12, mpireq, MPI_STATUSES_IGNORE, ierr)
222-
end if
129+
call sendrecv_fields(v_recv_s_dev, v_recv_e_dev, &
130+
v_send_s_dev, v_send_e_dev, &
131+
SZ*4*n_block, nproc, pprev, pnext)
223132

224-
call transeq_3fused_subs<<<blocks, threads>>>( &
225-
r_u_dev, v_dev, du_dev, dud_dev, d2u_dev, &
226-
du_recv_s_dev, du_recv_e_dev, &
227-
dud_recv_s_dev, dud_recv_e_dev, &
228-
d2u_recv_s_dev, d2u_recv_e_dev, &
229-
der1st%dist_sa_dev, der1st%dist_sc_dev, &
230-
der2nd%dist_sa_dev, der2nd%dist_sc_dev, &
231-
n, nu &
133+
call exec_dist_transeq_3fused( &
134+
r_u_dev, &
135+
u_dev, u_recv_s_dev, u_recv_e_dev, &
136+
v_dev, v_recv_s_dev, v_recv_e_dev, &
137+
du_dev, dud_dev, d2u_dev, &
138+
du_send_s_dev, du_send_e_dev, du_recv_s_dev, du_recv_e_dev, &
139+
dud_send_s_dev, dud_send_e_dev, dud_recv_s_dev, dud_recv_e_dev, &
140+
d2u_send_s_dev, d2u_send_e_dev, d2u_recv_s_dev, d2u_recv_e_dev, &
141+
der1st, der2nd, nu, nproc, pprev, pnext, blocks, threads &
232142
)
233143
end do
234144

0 commit comments

Comments
 (0)