@@ -5,8 +5,8 @@ program test_omp_tridiag
 
  use m_common, only: dp, pi
  use m_omp_common, only: SZ
- use m_omp_kernels_dist, only: der_univ_dist_omp, der_univ_subs_omp
  use m_omp_sendrecv, only: sendrecv_fields
+ use m_omp_exec_dist, only: exec_dist_tds_compact
 
  use m_tdsops, only: tdsops_t, tdsops_init
 
@@ -34,7 +34,7 @@ program test_omp_tridiag
 
  integer :: n, n_block, i, j, k, n_halo, n_iters, iters, n_loc
  integer :: n_glob
- integer :: nrank, nproc, pprev, pnext, tag1= 1234 , tag2= 1234
+ integer :: nrank, nproc, pprev, pnext, tag1 = 1234 , tag2 = 1234
  integer :: ierr, ndevs, devnum, memClockRt, memBusWidth
 
  real (dp) :: dx, dx_per, norm_du, tol = 1d-8 , tstart, tend
@@ -44,7 +44,7 @@ program test_omp_tridiag
  call MPI_Comm_rank(MPI_COMM_WORLD, nrank, ierr)
  call MPI_Comm_size(MPI_COMM_WORLD, nproc, ierr)
 
- if (nrank == 0 ) print * , ' Parallel run with' , nproc, ' ranks'
+ if (nrank == 0 ) print * , ' Parallel run with' , nproc, ' ranks'
 
  pnext = modulo (nrank - nproc + 1 , nproc)
  pprev = modulo (nrank - 1 , nproc)
@@ -54,14 +54,14 @@ program test_omp_tridiag
  n_block = 512 * 512 / SZ
  n_iters = 1
 
- allocate (u(SZ, n, n_block), du(SZ, n, n_block))
+ allocate (u(SZ, n, n_block), du(SZ, n, n_block))
 
  dx_per = 2 * pi/ n_glob
  dx = 2 * pi/ (n_glob - 1 )
 
- allocate (sin_0_2pi_per(n), cos_0_2pi_per(n))
- allocate (sin_0_2pi(n), cos_0_2pi(n))
- allocate (sin_stag(n), cos_stag(n))
+ allocate (sin_0_2pi_per(n), cos_0_2pi_per(n))
+ allocate (sin_0_2pi(n), cos_0_2pi(n))
+ allocate (sin_stag(n), cos_stag(n))
  do j = 1 , n
  sin_0_2pi_per(j) = sin (((j - 1 ) + nrank* n)* dx_per)
  cos_0_2pi_per(j) = cos (((j - 1 ) + nrank* n)* dx_per)
@@ -74,13 +74,13 @@ program test_omp_tridiag
  n_halo = 4
 
  ! arrays for exchanging data between ranks
- allocate (u_send_s(SZ, n_halo, n_block))
- allocate (u_send_e(SZ, n_halo, n_block))
- allocate (u_recv_s(SZ, n_halo, n_block))
- allocate (u_recv_e(SZ, n_halo, n_block))
+ allocate (u_send_s(SZ, n_halo, n_block))
+ allocate (u_send_e(SZ, n_halo, n_block))
+ allocate (u_recv_s(SZ, n_halo, n_block))
+ allocate (u_recv_e(SZ, n_halo, n_block))
 
- allocate (send_s(SZ, 1 , n_block), send_e(SZ, 1 , n_block))
- allocate (recv_s(SZ, 1 , n_block), recv_e(SZ, 1 , n_block))
+ allocate (send_s(SZ, 1 , n_block), send_e(SZ, 1 , n_block))
+ allocate (recv_s(SZ, 1 , n_block), recv_e(SZ, 1 , n_block))
 
  ! =========================================================================
  ! second derivative with periodic BC
@@ -97,17 +97,17 @@ program test_omp_tridiag
  )
 
  tend = omp_get_wtime()
- if (nrank == 0 ) print * , ' Total time' , tend- tstart
+ if (nrank == 0 ) print * , ' Total time' , tend - tstart
 
  call check_error_norm(du, sin_0_2pi_per, n, n_glob, n_block, 1 , norm_du)
- if (nrank == 0 ) print * , ' error norm second-deriv periodic' , norm_du
+ if (nrank == 0 ) print * , ' error norm second-deriv periodic' , norm_du
 
  if (nrank == 0 ) then
  if (norm_du > tol) then
  allpass = .false.
- write (stderr, ' (a)' ) ' Check 2nd derivatives, periodic BCs... failed'
+ write (stderr, ' (a)' ) ' Check 2nd derivatives, periodic BCs... failed'
  else
- write (stderr, ' (a)' ) ' Check 2nd derivatives, periodic BCs... passed'
+ write (stderr, ' (a)' ) ' Check 2nd derivatives, periodic BCs... passed'
  end if
  end if
 
@@ -124,14 +124,14 @@ program test_omp_tridiag
  )
 
  call check_error_norm(du, cos_0_2pi_per, n, n_glob, n_block, - 1 , norm_du)
- if (nrank == 0 ) print * , ' error norm first-deriv periodic' , norm_du
+ if (nrank == 0 ) print * , ' error norm first-deriv periodic' , norm_du
 
  if (nrank == 0 ) then
  if (norm_du > tol) then
  allpass = .false.
- write (stderr, ' (a)' ) ' Check 1st derivatives, periodic BCs... failed'
+ write (stderr, ' (a)' ) ' Check 1st derivatives, periodic BCs... failed'
  else
- write (stderr, ' (a)' ) ' Check 1st derivatives, periodic BCs... passed'
+ write (stderr, ' (a)' ) ' Check 1st derivatives, periodic BCs... passed'
  end if
  end if
 
@@ -161,14 +161,14 @@ program test_omp_tridiag
  )
 
  call check_error_norm(du, cos_0_2pi, n, n_glob, n_block, - 1 , norm_du)
- if (nrank == 0 ) print * , ' error norm first deriv dir-neu' , norm_du
+ if (nrank == 0 ) print * , ' error norm first deriv dir-neu' , norm_du
 
  if (nrank == 0 ) then
  if (norm_du > tol) then
  allpass = .false.
- write (stderr, ' (a)' ) ' Check 1st derivatives, dir-neu... failed'
+ write (stderr, ' (a)' ) ' Check 1st derivatives, dir-neu... failed'
  else
- write (stderr, ' (a)' ) ' Check 1st derivatives, dir-neu... passed'
+ write (stderr, ' (a)' ) ' Check 1st derivatives, dir-neu... passed'
  end if
  end if
 
@@ -189,14 +189,14 @@ program test_omp_tridiag
  )
 
  call check_error_norm(du, cos_stag, n_loc, n_glob, n_block, - 1 , norm_du)
- if (nrank == 0 ) print * , ' error norm interpolate' , norm_du
+ if (nrank == 0 ) print * , ' error norm interpolate' , norm_du
 
  if (nrank == 0 ) then
  if (norm_du > tol) then
  allpass = .false.
- write (stderr, ' (a)' ) ' Check interpolation... failed'
+ write (stderr, ' (a)' ) ' Check interpolation... failed'
  else
- write (stderr, ' (a)' ) ' Check interpolation... passed'
+ write (stderr, ' (a)' ) ' Check interpolation... passed'
  end if
  end if
 
@@ -217,21 +217,21 @@ program test_omp_tridiag
  )
 
  call check_error_norm(du, sin_0_2pi, n, n_glob, n_block, 1 , norm_du)
- if (nrank == 0 ) print * , ' error norm hyperviscous' , norm_du
+ if (nrank == 0 ) print * , ' error norm hyperviscous' , norm_du
 
  if (nrank == 0 ) then
  if (norm_du > tol) then
  allpass = .false.
- write (stderr, ' (a)' ) ' Check 2nd ders, hyperviscous, dir-neu... failed'
+ write (stderr, ' (a)' ) ' Check 2nd ders, hyperviscous, dir-neu... failed'
  else
- write (stderr, ' (a)' ) ' Check 2nd ders, hyperviscous, dir-neu... passed'
+ write (stderr, ' (a)' ) ' Check 2nd ders, hyperviscous, dir-neu... passed'
  end if
  end if
 
  ! =========================================================================
  ! BW utilisation and performance checks
  ! 3 in the first phase, 2 in the second phase, so 5 in total
- achievedBW = 5._dp * n_iters* n* n_block* SZ* dp/ (tend- tstart)
+ achievedBW = 5._dp * n_iters* n* n_block* SZ* dp/ (tend - tstart)
  call MPI_Allreduce(achievedBW, achievedBWmax, 1 , MPI_DOUBLE_PRECISION, &
  MPI_MAX, MPI_COMM_WORLD, ierr)
  call MPI_Allreduce(achievedBW, achievedBWmin, 1 , MPI_DOUBLE_PRECISION, &
@@ -247,13 +247,13 @@ program test_omp_tridiag
 
  if (nrank == 0 ) then
  print ' (a, f8.3, a)' , ' Available BW: ' , deviceBW/ 2 ** 30 , &
- ' GiB/s (per NUMA zone on ARCHER2)'
+ ' GiB/s (per NUMA zone on ARCHER2)'
  print ' (a, f5.2)' , ' Effective BW util min: %' , achievedBWmin/ deviceBW* 100
  print ' (a, f5.2)' , ' Effective BW util max: %' , achievedBWmax/ deviceBW* 100
  end if
 
  if (allpass) then
- if (nrank == 0 ) write (stderr, ' (a)' ) ' ALL TESTS PASSED SUCCESSFULLY.'
+ if (nrank == 0 ) write (stderr, ' (a)' ) ' ALL TESTS PASSED SUCCESSFULLY.'
  else
  error stop ' SOME TESTS FAILED.'
  end if
@@ -279,8 +279,7 @@ subroutine run_kernel(n_iters, n_block, u, du, tdsops, n, &
  send_s, send_e
  integer , intent (in ) :: nproc, pprev, pnext
 
- integer :: iters, i, j, k, ierr, tag1= 1234 , tag2= 1234
- integer :: srerr(4 ), mpireq(4 )
+ integer :: iters, i, j, k
 
  do iters = 1 , n_iters
  ! first copy halo data into buffers
@@ -289,8 +288,8 @@ subroutine run_kernel(n_iters, n_block, u, du, tdsops, n, &
  do j = 1 , 4
  ! $omp simd
  do i = 1 , SZ
- u_send_s(i,j, k) = u(i,j, k)
- u_send_e(i,j, k) = u(i,n - n_halo+ j, k)
+ u_send_s(i, j, k) = u(i, j, k)
+ u_send_e(i, j, k) = u(i, n - n_halo + j, k)
  end do
  ! $omp end simd
  end do
@@ -299,48 +298,11 @@ subroutine run_kernel(n_iters, n_block, u, du, tdsops, n, &
 
  ! halo exchange
  call sendrecv_fields(u_recv_s, u_recv_e, u_send_s, u_send_e, &
- SZ* n_halo* n_block, nproc, pprev, pnext)
-
- ! $omp parallel do
- do k = 1 , n_block
- call der_univ_dist_omp( &
- du(:, :, k), send_s(:, :, k), send_e(:, :, k), u(:, :, k), &
- u_recv_s(:, :, k), u_recv_e(:, :, k), &
- tdsops% coeffs_s, tdsops% coeffs_e, tdsops% coeffs, n, &
- tdsops% dist_fw, tdsops% dist_bw, tdsops% dist_af &
- )
- end do
- ! $omp end parallel do
+ SZ* n_halo* n_block, nproc, pprev, pnext)
 
- ! halo exchange for 2x2 systems
- if (nproc == 1 ) then
- recv_s = send_e
- recv_e = send_s
- else
- ! MPI send/recv for multi-rank simulations
- call MPI_Isend(send_s, SZ* n_block, &
- MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
- mpireq(1 ), srerr(1 ))
- call MPI_Irecv(recv_e, SZ* n_block, &
- MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
- mpireq(2 ), srerr(2 ))
- call MPI_Isend(send_e, SZ* n_block, &
- MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
- mpireq(3 ), srerr(3 ))
- call MPI_Irecv(recv_s, SZ* n_block, &
- MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
- mpireq(4 ), srerr(4 ))
-
- call MPI_Waitall(4 , mpireq, MPI_STATUSES_IGNORE, ierr)
- end if
+ call exec_dist_tds_compact(du, u, u_recv_s, u_recv_e, send_s, send_e, &
+ recv_s, recv_e, tdsops, nproc, pprev, pnext, n_block)
 
- ! $omp parallel do
- do k = 1 , n_block
- call der_univ_subs_omp(du(:, :, k), &
- recv_s(:, :, k), recv_e(:, :, k), &
- n, tdsops% dist_sa, tdsops% dist_sc)
- end do
- ! $omp end parallel do
 
  end do
  end subroutine run_kernel
 
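Note on the hunk above: the inline distributed tridiagonal solve in run_kernel (a per-block forward/backward pass via der_univ_dist_omp, an MPI exchange of the reduced 2x2 boundary systems, then the substitution pass via der_univ_subs_omp) is now hidden behind a single call to exec_dist_tds_compact from m_omp_exec_dist, which is also why the tag1/tag2, srerr and mpireq locals disappear from run_kernel in the earlier hunk. The sketch below is reconstructed from the removed lines to show what such a routine is assumed to do; it is not the actual m_omp_exec_dist source, the _sketch name is made up, and reading the local size from tdsops%n is an assumption (the old code received n as an explicit argument).

! Sketch only: assumed behaviour of exec_dist_tds_compact, reconstructed from
! the code removed above. Not the real m_omp_exec_dist implementation.
subroutine exec_dist_tds_compact_sketch(du, u, u_recv_s, u_recv_e, &
                                        send_s, send_e, recv_s, recv_e, &
                                        tdsops, nproc, pprev, pnext, n_block)
  use mpi
  use m_common, only: dp
  use m_omp_common, only: SZ
  use m_omp_kernels_dist, only: der_univ_dist_omp, der_univ_subs_omp
  use m_tdsops, only: tdsops_t

  real(dp), intent(out) :: du(:, :, :)
  real(dp), intent(in) :: u(:, :, :), u_recv_s(:, :, :), u_recv_e(:, :, :)
  real(dp), intent(inout) :: send_s(:, :, :), send_e(:, :, :), &
                             recv_s(:, :, :), recv_e(:, :, :)
  type(tdsops_t), intent(in) :: tdsops
  integer, intent(in) :: nproc, pprev, pnext, n_block

  integer :: k, n, ierr
  integer :: tag1 = 1234, tag2 = 1234
  integer :: srerr(4), mpireq(4)

  n = tdsops%n  ! assumption: tdsops carries the local subdomain size

  ! phase 1: independent forward/backward passes on each block; these also
  ! fill send_s/send_e with the boundary rows of the reduced system
  !$omp parallel do
  do k = 1, n_block
    call der_univ_dist_omp( &
      du(:, :, k), send_s(:, :, k), send_e(:, :, k), u(:, :, k), &
      u_recv_s(:, :, k), u_recv_e(:, :, k), &
      tdsops%coeffs_s, tdsops%coeffs_e, tdsops%coeffs, n, &
      tdsops%dist_fw, tdsops%dist_bw, tdsops%dist_af &
      )
  end do
  !$omp end parallel do

  ! exchange the reduced 2x2 system boundaries with neighbouring ranks
  if (nproc == 1) then
    recv_s = send_e
    recv_e = send_s
  else
    call MPI_Isend(send_s, SZ*n_block, MPI_DOUBLE_PRECISION, pprev, tag1, &
                   MPI_COMM_WORLD, mpireq(1), srerr(1))
    call MPI_Irecv(recv_e, SZ*n_block, MPI_DOUBLE_PRECISION, pnext, tag2, &
                   MPI_COMM_WORLD, mpireq(2), srerr(2))
    call MPI_Isend(send_e, SZ*n_block, MPI_DOUBLE_PRECISION, pnext, tag2, &
                   MPI_COMM_WORLD, mpireq(3), srerr(3))
    call MPI_Irecv(recv_s, SZ*n_block, MPI_DOUBLE_PRECISION, pprev, tag1, &
                   MPI_COMM_WORLD, mpireq(4), srerr(4))
    call MPI_Waitall(4, mpireq, MPI_STATUSES_IGNORE, ierr)
  end if

  ! phase 2: substitution using the neighbours' boundary values
  !$omp parallel do
  do k = 1, n_block
    call der_univ_subs_omp(du(:, :, k), recv_s(:, :, k), recv_e(:, :, k), &
                           n, tdsops%dist_sa, tdsops%dist_sc)
  end do
  !$omp end parallel do
end subroutine exec_dist_tds_compact_sketch

The dummy-argument order here matches the call added in the hunk above, so the test only has to hand over the buffers and the tdsops object.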
@@ -384,7 +346,7 @@ subroutine check_error_norm(du, line, n, n_glob, n_block, c, norm)
  norm = norm2(du(:, 1 :n, :))
  norm = norm* norm/ n_glob/ n_block/ SZ
  call MPI_Allreduce(MPI_IN_PLACE, norm, 1 , MPI_DOUBLE_PRECISION, &
- MPI_SUM, MPI_COMM_WORLD, ierr)
+ MPI_SUM, MPI_COMM_WORLD, ierr)
  norm = sqrt (norm)
 
  end subroutine check_error_norm