@@ -5,7 +5,8 @@ program test_cuda_tridiag
55
66 use m_common, only: dp, pi
77 use m_cuda_common, only: SZ
8- use m_cuda_kernels_dist, only: transeq_3fused_dist, transeq_3fused_subs
8+ use m_cuda_exec_dist, only: exec_dist_transeq_3fused
9+ use m_cuda_sendrecv, only: sendrecv_fields, sendrecv_3fields
910 use m_cuda_tdsops, only: cuda_tdsops_t
1011
1112 implicit none
@@ -119,116 +120,25 @@ program test_cuda_tridiag
119120 v_send_s_dev(:, :, :) = v_dev(:, 1 :4 , :)
120121 v_send_e_dev(:, :, :) = v_dev(:, n - n_halo + 1 :n, :)
121122
122- ! halo exchange
123- if (nproc == 1 ) then
124- u_recv_s_dev = u_send_e_dev
125- u_recv_e_dev = u_send_s_dev
126- v_recv_s_dev = v_send_e_dev
127- v_recv_e_dev = v_send_s_dev
128- else
129- ! MPI send/recv for multi-rank simulations
130- call MPI_Isend(u_send_s_dev, SZ* n_halo* n_block, &
131- MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
132- mpireq(1 ), srerr(1 ))
133- call MPI_Irecv(u_recv_e_dev, SZ* n_halo* n_block, &
134- MPI_DOUBLE_PRECISION, pnext, tag1, MPI_COMM_WORLD, &
135- mpireq(2 ), srerr(2 ))
136- call MPI_Isend(u_send_e_dev, SZ* n_halo* n_block, &
137- MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
138- mpireq(3 ), srerr(3 ))
139- call MPI_Irecv(u_recv_s_dev, SZ* n_halo* n_block, &
140- MPI_DOUBLE_PRECISION, pprev, tag2, MPI_COMM_WORLD, &
141- mpireq(4 ), srerr(4 ))
142-
143- call MPI_Isend(v_send_s_dev, SZ* n_halo* n_block, &
144- MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
145- mpireq(5 ), srerr(5 ))
146- call MPI_Irecv(v_recv_e_dev, SZ* n_halo* n_block, &
147- MPI_DOUBLE_PRECISION, pnext, tag1, MPI_COMM_WORLD, &
148- mpireq(6 ), srerr(6 ))
149- call MPI_Isend(v_send_e_dev, SZ* n_halo* n_block, &
150- MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
151- mpireq(7 ), srerr(7 ))
152- call MPI_Irecv(v_recv_s_dev, SZ* n_halo* n_block, &
153- MPI_DOUBLE_PRECISION, pprev, tag2, MPI_COMM_WORLD, &
154- mpireq(8 ), srerr(8 ))
155-
156- call MPI_Waitall(8 , mpireq, MPI_STATUSES_IGNORE, ierr)
157- end if
158123
159- call transeq_3fused_dist<<<blocks, threads>>>( &
160- du_dev, dud_dev, d2u_dev, &
161- du_send_s_dev, du_send_e_dev, &
162- dud_send_s_dev, dud_send_e_dev, &
163- d2u_send_s_dev, d2u_send_e_dev, &
164- u_dev, u_recv_s_dev, u_recv_e_dev, &
165- v_dev, v_recv_s_dev, v_recv_e_dev, n, &
166- der1st% coeffs_s_dev, der1st% coeffs_e_dev, der1st% coeffs_dev, &
167- der1st% dist_fw_dev, der1st% dist_bw_dev, der1st% dist_af_dev, &
168- der2nd% coeffs_s_dev, der2nd% coeffs_e_dev, der2nd% coeffs_dev, &
169- der2nd% dist_fw_dev, der2nd% dist_bw_dev, der2nd% dist_af_dev &
170- )
124+ ! halo exchange for u; count uses n_halo (not a literal 4), as the pre-refactor MPI calls did
125+ call sendrecv_fields(u_recv_s_dev, u_recv_e_dev, &
126+ u_send_s_dev, u_send_e_dev, &
127+ SZ*n_halo*n_block, nproc, pprev, pnext)
171128
172- ! halo exchange for 2x2 systems
173- if (nproc == 1 ) then
174- du_recv_s_dev = du_send_e_dev
175- du_recv_e_dev = du_send_s_dev
176- dud_recv_s_dev = dud_send_e_dev
177- dud_recv_e_dev = dud_send_s_dev
178- d2u_recv_s_dev = d2u_send_e_dev
179- d2u_recv_e_dev = d2u_send_s_dev
180- else
181- ! MPI send/recv for multi-rank simulations
182- call MPI_Isend(du_send_s_dev, SZ* n_block, &
183- MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
184- mpireq(1 ), srerr(1 ))
185- call MPI_Irecv(du_recv_e_dev, SZ* n_block, &
186- MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
187- mpireq(2 ), srerr(2 ))
188- call MPI_Isend(du_send_e_dev, SZ* n_block, &
189- MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
190- mpireq(3 ), srerr(3 ))
191- call MPI_Irecv(du_recv_s_dev, SZ* n_block, &
192- MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
193- mpireq(4 ), srerr(4 ))
194-
195- call MPI_Isend(dud_send_s_dev, SZ* n_block, &
196- MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
197- mpireq(5 ), srerr(5 ))
198- call MPI_Irecv(dud_recv_e_dev, SZ* n_block, &
199- MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
200- mpireq(6 ), srerr(6 ))
201- call MPI_Isend(dud_send_e_dev, SZ* n_block, &
202- MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
203- mpireq(7 ), srerr(7 ))
204- call MPI_Irecv(dud_recv_s_dev, SZ* n_block, &
205- MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
206- mpireq(8 ), srerr(8 ))
207-
208- call MPI_Isend(d2u_send_s_dev, SZ* n_block, &
209- MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
210- mpireq(9 ), srerr(9 ))
211- call MPI_Irecv(d2u_recv_e_dev, SZ* n_block, &
212- MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
213- mpireq(10 ), srerr(10 ))
214- call MPI_Isend(d2u_send_e_dev, SZ* n_block, &
215- MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
216- mpireq(11 ), srerr(11 ))
217- call MPI_Irecv(d2u_recv_s_dev, SZ* n_block, &
218- MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
219- mpireq(12 ), srerr(12 ))
220-
221- call MPI_Waitall(12 , mpireq, MPI_STATUSES_IGNORE, ierr)
222- end if
129+ call sendrecv_fields(v_recv_s_dev, v_recv_e_dev, & ! halo exchange for v
130+ v_send_s_dev, v_send_e_dev, &
131+ SZ*n_halo*n_block, nproc, pprev, pnext) ! count follows n_halo, matching the v_send_e_dev slicing
223132
224- call transeq_3fused_subs<<<blocks, threads>>>( &
225- r_u_dev, v_dev, du_dev, dud_dev, d2u_dev, &
226- du_recv_s_dev, du_recv_e_dev, &
227- dud_recv_s_dev, dud_recv_e_dev, &
228- d2u_recv_s_dev, d2u_recv_e_dev, &
229- der1st% dist_sa_dev, der1st% dist_sc_dev, &
230- der2nd% dist_sa_dev, der2nd% dist_sc_dev, &
231- n, nu &
133+ call exec_dist_transeq_3fused( & ! NOTE(review): presumably wraps the dist kernel, 2x2 halo exchange and subs kernel removed by this change; confirm in m_cuda_exec_dist
134+ r_u_dev, &
135+ u_dev, u_recv_s_dev, u_recv_e_dev, & ! u and the halos received by the exchange above
136+ v_dev, v_recv_s_dev, v_recv_e_dev, & ! v and the halos received by the exchange above
137+ du_dev, dud_dev, d2u_dev, &
138+ du_send_s_dev, du_send_e_dev, du_recv_s_dev, du_recv_e_dev, & ! boundary buffers formerly exchanged by hand
139+ dud_send_s_dev, dud_send_e_dev, dud_recv_s_dev, dud_recv_e_dev, &
140+ d2u_send_s_dev, d2u_send_e_dev, d2u_recv_s_dev, d2u_recv_e_dev, &
141+ der1st, der2nd, nu, nproc, pprev, pnext, blocks, threads & ! rank topology and launch configuration are passed in
232142 )
233143 end do
234144
0 commit comments