Commit 3769b96

test(cuda): Add a test for the fused distributed transeq kernel.
1 parent 477963f commit 3769b96

File tree

2 files changed: +292 -0 lines changed


tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ set(TESTSRC
 set(CUDATESTSRC
   cuda/test_cuda_allocator.f90
   cuda/test_cuda_tridiag.f90
+  cuda/test_cuda_transeq.f90
 )

 if(${CMAKE_Fortran_COMPILER_ID} STREQUAL "PGI")

tests/cuda/test_cuda_transeq.f90

Lines changed: 291 additions & 0 deletions
@@ -0,0 +1,291 @@
program test_cuda_transeq
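  ! Functional and performance test for the fused distributed transeq
  ! kernels: each rank initialises u = sin(x), v = cos(x) on a periodic
  ! slab, runs n_iters iterations of the halo exchange + fused kernel
  ! sequence, reports the achieved memory bandwidth, and compares r_u
  ! against the analytical result.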
  use iso_fortran_env, only: stderr => error_unit
  use cudafor
  use mpi

  use m_common, only: dp, pi
  use m_cuda_common, only: SZ
  use m_cuda_kernels_dist, only: transeq_3fused_dist, transeq_3fused_subs
  use m_cuda_tdsops, only: cuda_tdsops_t

  implicit none

  logical :: allpass = .true.
  real(dp), allocatable, dimension(:, :, :) :: u, v, r_u
  real(dp), device, allocatable, dimension(:, :, :) :: &
    u_dev, v_dev, r_u_dev, & ! main fields u, v and result r_u
    du_dev, dud_dev, d2u_dev ! intermediate solution arrays
  real(dp), device, allocatable, dimension(:, :, :) :: &
    du_recv_s_dev, du_recv_e_dev, du_send_s_dev, du_send_e_dev, &
    dud_recv_s_dev, dud_recv_e_dev, dud_send_s_dev, dud_send_e_dev, &
    d2u_recv_s_dev, d2u_recv_e_dev, d2u_send_s_dev, d2u_send_e_dev

  real(dp), device, allocatable, dimension(:, :, :) :: &
    u_send_s_dev, u_send_e_dev, u_recv_s_dev, u_recv_e_dev, &
    v_send_s_dev, v_send_e_dev, v_recv_s_dev, v_recv_e_dev

  type(cuda_tdsops_t) :: der1st, der2nd

  integer :: n, n_block, i, j, k, n_halo, n_iters
  integer :: n_glob
  integer :: nrank, nproc, pprev, pnext, tag1=1234, tag2=1234
  integer :: srerr(12), mpireq(12)
  integer :: ierr, ndevs, devnum, memClockRt, memBusWidth

  type(dim3) :: blocks, threads
  real(dp) :: dx, dx_per, nu, norm_du, tol = 1d-8, tstart, tend
  real(dp) :: achievedBW, deviceBW, achievedBWmax, achievedBWmin

  call MPI_Init(ierr)
  call MPI_Comm_rank(MPI_COMM_WORLD, nrank, ierr)
  call MPI_Comm_size(MPI_COMM_WORLD, nproc, ierr)

  if (nrank == 0) print*, 'Parallel run with', nproc, 'ranks'

  ierr = cudaGetDeviceCount(ndevs)
  ierr = cudaSetDevice(mod(nrank, ndevs)) ! round-robin
  ierr = cudaGetDevice(devnum)

  !print*, 'I am rank', nrank, 'I am running on device', devnum
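  ! periodic neighbour ranks: pnext wraps from the last rank back to 0,
  ! pprev wraps from rank 0 to nproc - 1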
  pnext = modulo(nrank - nproc + 1, nproc)
  pprev = modulo(nrank - 1, nproc)

  n_glob = 512*4
  n = n_glob/nproc
  n_block = 512*512/SZ
  n_iters = 100

  nu = 1._dp

  allocate(u(SZ, n, n_block), v(SZ, n, n_block), r_u(SZ, n, n_block))

  ! main input fields
  allocate(u_dev(SZ, n, n_block), v_dev(SZ, n, n_block))
  ! field for storing the result
  allocate(r_u_dev(SZ, n, n_block))
  ! intermediate solution fields
  allocate(du_dev(SZ, n, n_block))
  allocate(dud_dev(SZ, n, n_block))
  allocate(d2u_dev(SZ, n, n_block))

  dx_per = 2*pi/n_glob
  dx = 2*pi/(n_glob - 1)

  do k = 1, n_block
    do j = 1, n
      do i = 1, SZ
        u(i, j, k) = sin((j - 1 + nrank*n)*dx_per)
        v(i, j, k) = cos((j - 1 + nrank*n)*dx_per)
      end do
    end do
  end do

  ! move data to device
  u_dev = u
  v_dev = v

  n_halo = 4

  ! arrays for exchanging data between ranks
  allocate(u_send_s_dev(SZ, n_halo, n_block))
  allocate(u_send_e_dev(SZ, n_halo, n_block))
  allocate(u_recv_s_dev(SZ, n_halo, n_block))
  allocate(u_recv_e_dev(SZ, n_halo, n_block))
  allocate(v_send_s_dev(SZ, n_halo, n_block))
  allocate(v_send_e_dev(SZ, n_halo, n_block))
  allocate(v_recv_s_dev(SZ, n_halo, n_block))
  allocate(v_recv_e_dev(SZ, n_halo, n_block))

  allocate(du_send_s_dev(SZ, 1, n_block), du_send_e_dev(SZ, 1, n_block))
  allocate(du_recv_s_dev(SZ, 1, n_block), du_recv_e_dev(SZ, 1, n_block))
  allocate(dud_send_s_dev(SZ, 1, n_block), dud_send_e_dev(SZ, 1, n_block))
  allocate(dud_recv_s_dev(SZ, 1, n_block), dud_recv_e_dev(SZ, 1, n_block))
  allocate(d2u_send_s_dev(SZ, 1, n_block), d2u_send_e_dev(SZ, 1, n_block))
  allocate(d2u_recv_s_dev(SZ, 1, n_block), d2u_recv_e_dev(SZ, 1, n_block))

  ! preprocess the operator and coefficient arrays
  der1st = cuda_tdsops_t(n, dx_per, operation='first-deriv', &
                         scheme='compact6')
  der2nd = cuda_tdsops_t(n, dx_per, operation='second-deriv', &
                         scheme='compact6')

  blocks = dim3(n_block, 1, 1)
  threads = dim3(SZ, 1, 1)

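  ! Time n_iters iterations of the full sequence: pack the halo buffers,
  ! exchange halos, run the fused distributed kernel, exchange the
  ! reduced-system interface values, then substitute back into r_u.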
  call cpu_time(tstart)
  do i = 1, n_iters
    u_send_s_dev(:, :, :) = u_dev(:, 1:4, :)
    u_send_e_dev(:, :, :) = u_dev(:, n - n_halo + 1:n, :)
    v_send_s_dev(:, :, :) = v_dev(:, 1:4, :)
    v_send_e_dev(:, :, :) = v_dev(:, n - n_halo + 1:n, :)

    ! halo exchange
    if (nproc == 1) then
      u_recv_s_dev = u_send_e_dev
      u_recv_e_dev = u_send_s_dev
      v_recv_s_dev = v_send_e_dev
      v_recv_e_dev = v_send_s_dev
    else
      ! MPI send/recv for multi-rank simulations
      call MPI_Isend(u_send_s_dev, SZ*n_halo*n_block, &
                     MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
                     mpireq(1), srerr(1))
      call MPI_Irecv(u_recv_e_dev, SZ*n_halo*n_block, &
                     MPI_DOUBLE_PRECISION, pnext, tag1, MPI_COMM_WORLD, &
                     mpireq(2), srerr(2))
      call MPI_Isend(u_send_e_dev, SZ*n_halo*n_block, &
                     MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
                     mpireq(3), srerr(3))
      call MPI_Irecv(u_recv_s_dev, SZ*n_halo*n_block, &
                     MPI_DOUBLE_PRECISION, pprev, tag2, MPI_COMM_WORLD, &
                     mpireq(4), srerr(4))

      call MPI_Isend(v_send_s_dev, SZ*n_halo*n_block, &
                     MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
                     mpireq(5), srerr(5))
      call MPI_Irecv(v_recv_e_dev, SZ*n_halo*n_block, &
                     MPI_DOUBLE_PRECISION, pnext, tag1, MPI_COMM_WORLD, &
                     mpireq(6), srerr(6))
      call MPI_Isend(v_send_e_dev, SZ*n_halo*n_block, &
                     MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
                     mpireq(7), srerr(7))
      call MPI_Irecv(v_recv_s_dev, SZ*n_halo*n_block, &
                     MPI_DOUBLE_PRECISION, pprev, tag2, MPI_COMM_WORLD, &
                     mpireq(8), srerr(8))

      call MPI_Waitall(8, mpireq, MPI_STATUSES_IGNORE, ierr)
    end if

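    ! transeq_3fused_dist performs the local part of the distributed
    ! tridiagonal solves for du, dud and d2u and fills the single-plane
    ! *_send buffers with the interface values needed by the reduced
    ! 2x2 systems exchanged below (see m_cuda_kernels_dist).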
    call transeq_3fused_dist<<<blocks, threads>>>( &
      du_dev, dud_dev, d2u_dev, &
      du_send_s_dev, du_send_e_dev, &
      dud_send_s_dev, dud_send_e_dev, &
      d2u_send_s_dev, d2u_send_e_dev, &
      u_dev, u_recv_s_dev, u_recv_e_dev, &
      v_dev, v_recv_s_dev, v_recv_e_dev, n, &
      der1st%coeffs_s_dev, der1st%coeffs_e_dev, der1st%coeffs_dev, &
      der1st%dist_fw_dev, der1st%dist_bw_dev, der1st%dist_af_dev, &
      der2nd%coeffs_s_dev, der2nd%coeffs_e_dev, der2nd%coeffs_dev, &
      der2nd%dist_fw_dev, der2nd%dist_bw_dev, der2nd%dist_af_dev &
      )

    ! halo exchange for 2x2 systems
    if (nproc == 1) then
      du_recv_s_dev = du_send_e_dev
      du_recv_e_dev = du_send_s_dev
      dud_recv_s_dev = dud_send_e_dev
      dud_recv_e_dev = dud_send_s_dev
      d2u_recv_s_dev = d2u_send_e_dev
      d2u_recv_e_dev = d2u_send_s_dev
    else
      ! MPI send/recv for multi-rank simulations
      call MPI_Isend(du_send_s_dev, SZ*n_block, &
                     MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
                     mpireq(1), srerr(1))
      call MPI_Irecv(du_recv_e_dev, SZ*n_block, &
                     MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
                     mpireq(2), srerr(2))
      call MPI_Isend(du_send_e_dev, SZ*n_block, &
                     MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
                     mpireq(3), srerr(3))
      call MPI_Irecv(du_recv_s_dev, SZ*n_block, &
                     MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
                     mpireq(4), srerr(4))

      call MPI_Isend(dud_send_s_dev, SZ*n_block, &
                     MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
                     mpireq(5), srerr(5))
      call MPI_Irecv(dud_recv_e_dev, SZ*n_block, &
                     MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
                     mpireq(6), srerr(6))
      call MPI_Isend(dud_send_e_dev, SZ*n_block, &
                     MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
                     mpireq(7), srerr(7))
      call MPI_Irecv(dud_recv_s_dev, SZ*n_block, &
                     MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
                     mpireq(8), srerr(8))

      call MPI_Isend(d2u_send_s_dev, SZ*n_block, &
                     MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
                     mpireq(9), srerr(9))
      call MPI_Irecv(d2u_recv_e_dev, SZ*n_block, &
                     MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
                     mpireq(10), srerr(10))
      call MPI_Isend(d2u_send_e_dev, SZ*n_block, &
                     MPI_DOUBLE_PRECISION, pnext, tag2, MPI_COMM_WORLD, &
                     mpireq(11), srerr(11))
      call MPI_Irecv(d2u_recv_s_dev, SZ*n_block, &
                     MPI_DOUBLE_PRECISION, pprev, tag1, MPI_COMM_WORLD, &
                     mpireq(12), srerr(12))

      call MPI_Waitall(12, mpireq, MPI_STATUSES_IGNORE, ierr)
    end if

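    ! transeq_3fused_subs solves the reduced interface systems from the
    ! received boundary values, completes du, dud and d2u, and assembles
    ! the transport-equation RHS into r_u_dev.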
    call transeq_3fused_subs<<<blocks, threads>>>( &
      r_u_dev, v_dev, du_dev, dud_dev, d2u_dev, &
      du_recv_s_dev, du_recv_e_dev, &
      dud_recv_s_dev, dud_recv_e_dev, &
      d2u_recv_s_dev, d2u_recv_e_dev, &
      der1st%dist_sa_dev, der1st%dist_sc_dev, &
      der2nd%dist_sa_dev, der2nd%dist_sc_dev, &
      n, nu &
      )
  end do

  call cpu_time(tend)
  if (nrank == 0) print*, 'Total time', tend - tstart

  ! BW utilisation and performance checks
  ! 11 in the first phase, 5 in the second phase, 16 in total
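  ! the expression below uses the kind value dp (8 for double precision
  ! here) as the element size in bytes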
  achievedBW = 16._dp*n_iters*n*n_block*SZ*dp/(tend - tstart)
  call MPI_Allreduce(achievedBW, achievedBWmax, 1, MPI_DOUBLE_PRECISION, &
                     MPI_MAX, MPI_COMM_WORLD, ierr)
  call MPI_Allreduce(achievedBW, achievedBWmin, 1, MPI_DOUBLE_PRECISION, &
                     MPI_MIN, MPI_COMM_WORLD, ierr)

  if (nrank == 0) then
    print'(a, f8.3, a)', 'Achieved BW min: ', achievedBWmin/2**30, ' GiB/s'
    print'(a, f8.3, a)', 'Achieved BW max: ', achievedBWmax/2**30, ' GiB/s'
  end if

  ierr = cudaDeviceGetAttribute(memClockRt, cudaDevAttrMemoryClockRate, 0)
  ierr = cudaDeviceGetAttribute(memBusWidth, &
                                cudaDevAttrGlobalMemoryBusWidth, 0)
  deviceBW = 2*memBusWidth/8._dp*memClockRt*1000

  if (nrank == 0) then
    print'(a, f8.3, a)', 'Device BW: ', deviceBW/2**30, ' GiB/s'
    print'(a, f5.2)', 'Effective BW util min: %', achievedBWmin/deviceBW*100
    print'(a, f5.2)', 'Effective BW util max: %', achievedBWmax/deviceBW*100
  end if

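  ! With u = sin(x) and v = cos(x): du/dx = cos(x), d(u*v)/dx =
  ! cos(x)**2 - sin(x)**2 and d2u/dx2 = -sin(x), so the skew-symmetric
  ! RHS -0.5*(v*du/dx + d(u*v)/dx) + nu*d2u/dx2 reduces to
  ! -v*v + 0.5*u*u - nu*u, which is the reference used below.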
  ! check error
  r_u = r_u_dev
  r_u = r_u - (-v*v + 0.5_dp*u*u - nu*u)
  norm_du = norm2(r_u)
  norm_du = norm_du*norm_du/n_glob/n_block/SZ
  call MPI_Allreduce(MPI_IN_PLACE, norm_du, 1, MPI_DOUBLE_PRECISION, &
                     MPI_SUM, MPI_COMM_WORLD, ierr)
  norm_du = sqrt(norm_du)

  if (nrank == 0) print*, 'error norm', norm_du

  if (nrank == 0) then
    if ( norm_du > tol ) then
      allpass = .false.
      write(stderr, '(a)') 'Check fused transeq result... failed'
    else
      write(stderr, '(a)') 'Check fused transeq result... passed'
    end if
  end if

  if (allpass) then
    if (nrank == 0) write(stderr, '(a)') 'ALL TESTS PASSED SUCCESSFULLY.'
  else
    error stop 'SOME TESTS FAILED.'
  end if

  call MPI_Finalize(ierr)

end program test_cuda_transeq
