Skip to content

Commit 4527b70

Browse files
committed
refactor(omp): move omp exec_dist_tds_compact to its own file
1 parent 1754063 commit 4527b70

File tree

4 files changed

+108
-81
lines changed

4 files changed

+108
-81
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ set(SRC
99
omp/common.f90
1010
omp/kernels_dist.f90
1111
omp/sendrecv.f90
12+
omp/exec_dist.f90
1213
)
1314
set(CUDASRC
1415
cuda/backend.f90

src/omp/exec_dist.f90

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
module m_omp_exec_dist
2+
use mpi
3+
4+
use m_common, only: dp
5+
use m_omp_common, only: SZ
6+
use m_omp_kernels_dist, only: der_univ_dist, der_univ_subs
7+
use m_tdsops, only: tdsops_t
8+
use m_omp_sendrecv, only: sendrecv_fields
9+
10+
implicit none
11+
12+
contains
13+
14+
subroutine exec_dist_tds_compact( &
15+
du, u, u_recv_s, u_recv_e, du_send_s, du_send_e, du_recv_s, du_recv_e, &
16+
tdsops, nproc, pprev, pnext, n_block &
17+
)
18+
implicit none
19+
20+
! du = d(u)
21+
real(dp), dimension(:, :, :), intent(out) :: du
22+
real(dp), dimension(:, :, :), intent(in) :: u, u_recv_s, u_recv_e
23+
24+
! The ones below are intent(out) just so that we can write data in them,
25+
! not because we actually need the data they store later where this
26+
! subroutine is called. We absolutely don't care about the data they pass back
27+
real(dp), dimension(:, :, :), intent(out) :: &
28+
du_send_s, du_send_e, du_recv_s, du_recv_e
29+
30+
type(tdsops_t), intent(in) :: tdsops
31+
integer, intent(in) :: nproc, pprev, pnext
32+
integer, intent(in) :: n_block
33+
34+
integer :: n_data
35+
integer :: k
36+
37+
n_data = SZ*n_block
38+
39+
!$omp parallel do
40+
do k = 1, n_block
41+
call der_univ_dist( &
42+
du(:, :, k), du_send_s(:, :, k), du_send_e(:, :, k), u(:, :, k), &
43+
u_recv_s(:, :, k), u_recv_e(:, :, k), &
44+
tdsops%coeffs_s, tdsops%coeffs_e, tdsops%coeffs, tdsops%n, &
45+
tdsops%dist_fw, tdsops%dist_bw, tdsops%dist_af &
46+
)
47+
end do
48+
49+
! halo exchange for 2x2 systems
50+
call sendrecv_fields(du_recv_s, du_recv_e, du_send_s, du_send_e, &
51+
n_data, nproc, pprev, pnext)
52+
53+
!$omp parallel do
54+
do k = 1, n_block
55+
call der_univ_subs(du(:, :, k), &
56+
du_recv_s(:, :, k), du_recv_e(:, :, k), &
57+
tdsops%n, tdsops%dist_sa, tdsops%dist_sc)
58+
end do
59+
!$omp end parallel do
60+
61+
end subroutine exec_dist_tds_compact
62+
63+
end module m_omp_exec_dist
64+

src/omp/kernels_dist.f90

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ module m_omp_kernels_dist
88

99
contains
1010

11-
subroutine der_univ_dist_omp( &
11+
subroutine der_univ_dist( &
1212
du, send_u_s, send_u_e, u, u_s, u_e, coeffs_s, coeffs_e, coeffs, n, &
1313
ffr, fbc, faf &
1414
)
@@ -134,9 +134,9 @@ subroutine der_univ_dist_omp( &
134134
end do
135135
!$omp end simd
136136

137-
end subroutine der_univ_dist_omp
137+
end subroutine der_univ_dist
138138

139-
subroutine der_univ_subs_omp(du, recv_u_s, recv_u_e, n, dist_sa, dist_sc)
139+
subroutine der_univ_subs(du, recv_u_s, recv_u_e, n, dist_sa, dist_sc)
140140
implicit none
141141

142142
! Arguments
@@ -193,6 +193,6 @@ subroutine der_univ_subs_omp(du, recv_u_s, recv_u_e, n, dist_sa, dist_sc)
193193
end do
194194
!$omp end simd
195195

196-
end subroutine der_univ_subs_omp
196+
end subroutine der_univ_subs
197197

198198
end module m_omp_kernels_dist

tests/omp/test_omp_tridiag.f90

Lines changed: 39 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ program test_omp_tridiag
55

66
use m_common, only: dp, pi
77
use m_omp_common, only: SZ
8-
use m_omp_kernels_dist, only: der_univ_dist_omp, der_univ_subs_omp
98
use m_omp_sendrecv, only: sendrecv_fields
9+
use m_omp_exec_dist, only: exec_dist_tds_compact
1010

1111
use m_tdsops, only: tdsops_t, tdsops_init
1212

@@ -34,7 +34,7 @@ program test_omp_tridiag
3434

3535
integer :: n, n_block, i, j, k, n_halo, n_iters, iters, n_loc
3636
integer :: n_glob
37-
integer :: nrank, nproc, pprev, pnext, tag1=1234, tag2=1234
37+
integer :: nrank, nproc, pprev, pnext, tag1 = 1234, tag2 = 1234
3838
integer :: ierr, ndevs, devnum, memClockRt, memBusWidth
3939

4040
real(dp) :: dx, dx_per, norm_du, tol = 1d-8, tstart, tend
@@ -44,7 +44,7 @@ program test_omp_tridiag
4444
call MPI_Comm_rank(MPI_COMM_WORLD, nrank, ierr)
4545
call MPI_Comm_size(MPI_COMM_WORLD, nproc, ierr)
4646

47-
if (nrank == 0) print*, 'Parallel run with', nproc, 'ranks'
47+
if (nrank == 0) print *, 'Parallel run with', nproc, 'ranks'
4848

4949
pnext = modulo(nrank - nproc + 1, nproc)
5050
pprev = modulo(nrank - 1, nproc)
@@ -54,14 +54,14 @@ program test_omp_tridiag
5454
n_block = 512*512/SZ
5555
n_iters = 1
5656

57-
allocate(u(SZ, n, n_block), du(SZ, n, n_block))
57+
allocate (u(SZ, n, n_block), du(SZ, n, n_block))
5858

5959
dx_per = 2*pi/n_glob
6060
dx = 2*pi/(n_glob - 1)
6161

62-
allocate(sin_0_2pi_per(n), cos_0_2pi_per(n))
63-
allocate(sin_0_2pi(n), cos_0_2pi(n))
64-
allocate(sin_stag(n), cos_stag(n))
62+
allocate (sin_0_2pi_per(n), cos_0_2pi_per(n))
63+
allocate (sin_0_2pi(n), cos_0_2pi(n))
64+
allocate (sin_stag(n), cos_stag(n))
6565
do j = 1, n
6666
sin_0_2pi_per(j) = sin(((j - 1) + nrank*n)*dx_per)
6767
cos_0_2pi_per(j) = cos(((j - 1) + nrank*n)*dx_per)
@@ -74,13 +74,13 @@ program test_omp_tridiag
7474
n_halo = 4
7575

7676
! arrays for exchanging data between ranks
77-
allocate(u_send_s(SZ, n_halo, n_block))
78-
allocate(u_send_e(SZ, n_halo, n_block))
79-
allocate(u_recv_s(SZ, n_halo, n_block))
80-
allocate(u_recv_e(SZ, n_halo, n_block))
77+
allocate (u_send_s(SZ, n_halo, n_block))
78+
allocate (u_send_e(SZ, n_halo, n_block))
79+
allocate (u_recv_s(SZ, n_halo, n_block))
80+
allocate (u_recv_e(SZ, n_halo, n_block))
8181

82-
allocate(send_s(SZ, 1, n_block), send_e(SZ, 1, n_block))
83-
allocate(recv_s(SZ, 1, n_block), recv_e(SZ, 1, n_block))
82+
allocate (send_s(SZ, 1, n_block), send_e(SZ, 1, n_block))
83+
allocate (recv_s(SZ, 1, n_block), recv_e(SZ, 1, n_block))
8484

8585
! =========================================================================
8686
! second derivative with periodic BC
@@ -97,17 +97,17 @@ program test_omp_tridiag
9797
)
9898

9999
tend = omp_get_wtime()
100-
if (nrank == 0) print*, 'Total time', tend-tstart
100+
if (nrank == 0) print *, 'Total time', tend - tstart
101101

102102
call check_error_norm(du, sin_0_2pi_per, n, n_glob, n_block, 1, norm_du)
103-
if (nrank == 0) print*, 'error norm second-deriv periodic', norm_du
103+
if (nrank == 0) print *, 'error norm second-deriv periodic', norm_du
104104

105105
if (nrank == 0) then
106106
if (norm_du > tol) then
107107
allpass = .false.
108-
write(stderr, '(a)') 'Check 2nd derivatives, periodic BCs... failed'
108+
write (stderr, '(a)') 'Check 2nd derivatives, periodic BCs... failed'
109109
else
110-
write(stderr, '(a)') 'Check 2nd derivatives, periodic BCs... passed'
110+
write (stderr, '(a)') 'Check 2nd derivatives, periodic BCs... passed'
111111
end if
112112
end if
113113

@@ -124,14 +124,14 @@ program test_omp_tridiag
124124
)
125125

126126
call check_error_norm(du, cos_0_2pi_per, n, n_glob, n_block, -1, norm_du)
127-
if (nrank == 0) print*, 'error norm first-deriv periodic', norm_du
127+
if (nrank == 0) print *, 'error norm first-deriv periodic', norm_du
128128

129129
if (nrank == 0) then
130130
if (norm_du > tol) then
131131
allpass = .false.
132-
write(stderr, '(a)') 'Check 1st derivatives, periodic BCs... failed'
132+
write (stderr, '(a)') 'Check 1st derivatives, periodic BCs... failed'
133133
else
134-
write(stderr, '(a)') 'Check 1st derivatives, periodic BCs... passed'
134+
write (stderr, '(a)') 'Check 1st derivatives, periodic BCs... passed'
135135
end if
136136
end if
137137

@@ -161,14 +161,14 @@ program test_omp_tridiag
161161
)
162162

163163
call check_error_norm(du, cos_0_2pi, n, n_glob, n_block, -1, norm_du)
164-
if (nrank == 0) print*, 'error norm first deriv dir-neu', norm_du
164+
if (nrank == 0) print *, 'error norm first deriv dir-neu', norm_du
165165

166166
if (nrank == 0) then
167167
if (norm_du > tol) then
168168
allpass = .false.
169-
write(stderr, '(a)') 'Check 1st derivatives, dir-neu... failed'
169+
write (stderr, '(a)') 'Check 1st derivatives, dir-neu... failed'
170170
else
171-
write(stderr, '(a)') 'Check 1st derivatives, dir-neu... passed'
171+
write (stderr, '(a)') 'Check 1st derivatives, dir-neu... passed'
172172
end if
173173
end if
174174

@@ -189,14 +189,14 @@ program test_omp_tridiag
189189
)
190190

191191
call check_error_norm(du, cos_stag, n_loc, n_glob, n_block, -1, norm_du)
192-
if (nrank == 0) print*, 'error norm interpolate', norm_du
192+
if (nrank == 0) print *, 'error norm interpolate', norm_du
193193

194194
if (nrank == 0) then
195195
if (norm_du > tol) then
196196
allpass = .false.
197-
write(stderr, '(a)') 'Check interpolation... failed'
197+
write (stderr, '(a)') 'Check interpolation... failed'
198198
else
199-
write(stderr, '(a)') 'Check interpolation... passed'
199+
write (stderr, '(a)') 'Check interpolation... passed'
200200
end if
201201
end if
202202

@@ -217,21 +217,21 @@ program test_omp_tridiag
217217
)
218218

219219
call check_error_norm(du, sin_0_2pi, n, n_glob, n_block, 1, norm_du)
220-
if (nrank == 0) print*, 'error norm hyperviscous', norm_du
220+
if (nrank == 0) print *, 'error norm hyperviscous', norm_du
221221

222222
if (nrank == 0) then
223223
if (norm_du > tol) then
224224
allpass = .false.
225-
write(stderr, '(a)') 'Check 2nd ders, hyperviscous, dir-neu... failed'
225+
write (stderr, '(a)') 'Check 2nd ders, hyperviscous, dir-neu... failed'
226226
else
227-
write(stderr, '(a)') 'Check 2nd ders, hyperviscous, dir-neu... passed'
227+
write (stderr, '(a)') 'Check 2nd ders, hyperviscous, dir-neu... passed'
228228
end if
229229
end if
230230

231231
! =========================================================================
232232
! BW utilisation and performance checks
233233
! 3 in the first phase, 2 in the second phase, so 5 in total
234-
achievedBW = 5._dp*n_iters*n*n_block*SZ*dp/(tend-tstart)
234+
achievedBW = 5._dp*n_iters*n*n_block*SZ*dp/(tend - tstart)
235235
call MPI_Allreduce(achievedBW, achievedBWmax, 1, MPI_DOUBLE_PRECISION, &
236236
MPI_MAX, MPI_COMM_WORLD, ierr)
237237
call MPI_Allreduce(achievedBW, achievedBWmin, 1, MPI_DOUBLE_PRECISION, &
@@ -247,13 +247,13 @@ program test_omp_tridiag
247247

248248
if (nrank == 0) then
249249
print'(a, f8.3, a)', 'Available BW: ', deviceBW/2**30, &
250-
' GiB/s (per NUMA zone on ARCHER2)'
250+
' GiB/s (per NUMA zone on ARCHER2)'
251251
print'(a, f5.2)', 'Effective BW util min: %', achievedBWmin/deviceBW*100
252252
print'(a, f5.2)', 'Effective BW util max: %', achievedBWmax/deviceBW*100
253253
end if
254254

255255
if (allpass) then
256-
if (nrank == 0) write(stderr, '(a)') 'ALL TESTS PASSED SUCCESSFULLY.'
256+
if (nrank == 0) write (stderr, '(a)') 'ALL TESTS PASSED SUCCESSFULLY.'
257257
else
258258
error stop 'SOME TESTS FAILED.'
259259
end if
@@ -279,8 +279,7 @@ subroutine run_kernel(n_iters, n_block, u, du, tdsops, n, &
279279
send_s, send_e
280280
integer, intent(in) :: nproc, pprev, pnext
281281

282-
integer :: iters, i, j, k, ierr, tag1=1234, tag2=1234
283-
integer :: srerr(4), mpireq(4)
282+
integer :: iters, i, j, k
284283

285284
do iters = 1, n_iters
286285
! first copy halo data into buffers
@@ -289,8 +288,8 @@ subroutine run_kernel(n_iters, n_block, u, du, tdsops, n, &
289288
do j = 1, 4
290289
!$omp simd
291290
do i = 1, SZ
292-
u_send_s(i,j,k) = u(i,j,k)
293-
u_send_e(i,j,k) = u(i,n-n_halo+j,k)
291+
u_send_s(i, j, k) = u(i, j, k)
292+
u_send_e(i, j, k) = u(i, n - n_halo + j, k)
294293
end do
295294
!$omp end simd
296295
end do
@@ -299,48 +298,11 @@ subroutine run_kernel(n_iters, n_block, u, du, tdsops, n, &
299298

300299
! halo exchange
301300
call sendrecv_fields(u_recv_s, u_recv_e, u_send_s, u_send_e, &
302-
SZ*n_halo*n_block, nproc, pprev, pnext)
303-
304-
!$omp parallel do
305-
do k = 1, n_block
306-
call der_univ_dist_omp( &
307-
du(:, :, k), send_s(:, :, k), send_e(:, :, k), u(:, :, k), &
308-
u_recv_s(:, :, k), u_recv_e(:, :, k), &
309-
tdsops%coeffs_s, tdsops%coeffs_e, tdsops%coeffs, n, &
310-
tdsops%dist_fw, tdsops%dist_bw, tdsops%dist_af &
311-
)
312-
end do
313-
!$omp end parallel do
301+
SZ*n_halo*n_block, nproc, pprev, pnext)
314302

315-
! halo exchange for 2x2 systems
316-
if (nproc == 1) then
317-
recv_s = send_e
318-
recv_e = send_s
319-
else
320-
! MPI send/recv for multi-rank simulations
321-
call MPI_Isend(send_s, SZ*n_block, &
322-
MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
323-
mpireq(1), srerr(1))
324-
call MPI_Irecv(recv_e, SZ*n_block, &
325-
MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
326-
mpireq(2), srerr(2))
327-
call MPI_Isend(send_e, SZ*n_block, &
328-
MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
329-
mpireq(3), srerr(3))
330-
call MPI_Irecv(recv_s, SZ*n_block, &
331-
MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
332-
mpireq(4), srerr(4))
333-
334-
call MPI_Waitall(4, mpireq, MPI_STATUSES_IGNORE, ierr)
335-
end if
303+
call exec_dist_tds_compact(du, u, u_recv_s, u_recv_e, send_s, send_e, &
304+
recv_s, recv_e, tdsops, nproc, pprev, pnext, n_block)
336305

337-
!$omp parallel do
338-
do k = 1, n_block
339-
call der_univ_subs_omp(du(:, :, k), &
340-
recv_s(:, :, k), recv_e(:, :, k), &
341-
n, tdsops%dist_sa, tdsops%dist_sc)
342-
end do
343-
!$omp end parallel do
344306
end do
345307
end subroutine run_kernel
346308

@@ -384,7 +346,7 @@ subroutine check_error_norm(du, line, n, n_glob, n_block, c, norm)
384346
norm = norm2(du(:, 1:n, :))
385347
norm = norm*norm/n_glob/n_block/SZ
386348
call MPI_Allreduce(MPI_IN_PLACE, norm, 1, MPI_DOUBLE_PRECISION, &
387-
MPI_SUM, MPI_COMM_WORLD, ierr)
349+
MPI_SUM, MPI_COMM_WORLD, ierr)
388350
norm = sqrt(norm)
389351

390352
end subroutine check_error_norm

0 commit comments

Comments
 (0)