Skip to content

Commit 878496e

Browse files
committed
feat(tests/cuda): Use cuda_tdsops_t in CUDA tests.
1 parent bb3f5a4 commit 878496e

File tree

1 file changed

+68
-66
lines changed

1 file changed

+68
-66
lines changed

tests/cuda/test_cuda_tridiag.f90

Lines changed: 68 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -6,38 +6,30 @@ program test_cuda_tridiag
66
use m_common, only: dp, pi
77
use m_cuda_common, only: SZ
88
use m_cuda_kernels_dist, only: der_univ_dist, der_univ_subs
9-
use m_derparams, only: der_1_vv, der_2_vv
9+
use m_cuda_tdsops, only: cuda_tdsops_t, cuda_tdsops_init
1010

1111
implicit none
1212

1313
logical :: allpass = .true.
14-
real(dp), allocatable, dimension(:,:,:) :: u, du, u_s, u_e
15-
real(dp), device, allocatable, dimension(:,:,:) :: u_dev, du_dev
16-
real(dp), device, allocatable, dimension(:,:,:) :: u_recv_s_dev, u_recv_e_dev, &
17-
u_send_s_dev, u_send_e_dev
14+
real(dp), allocatable, dimension(:, :, :) :: u, du
15+
real(dp), device, allocatable, dimension(:, :, :) :: u_dev, du_dev
16+
real(dp), device, allocatable, dimension(:, :, :) :: &
17+
u_recv_s_dev, u_recv_e_dev, u_send_s_dev, u_send_e_dev
1818

19-
real(dp), device, allocatable, dimension(:,:,:) :: send_s_dev, send_e_dev, &
20-
recv_s_dev, recv_e_dev
19+
real(dp), device, allocatable, dimension(:, :, :) :: &
20+
send_s_dev, send_e_dev, recv_s_dev, recv_e_dev
2121

22-
real(dp), allocatable, dimension(:,:) :: coeffs_s, coeffs_e
23-
real(dp), allocatable, dimension(:) :: coeffs, dist_fr, dist_bc, dist_af, &
24-
dist_sa, dist_sc
22+
type(cuda_tdsops_t) :: tdsops
2523

26-
real(dp), device, allocatable, dimension(:,:) :: coeffs_s_dev, coeffs_e_dev
27-
real(dp), device, allocatable, dimension(:) :: coeffs_dev, &
28-
dist_fr_dev, dist_bc_dev, &
29-
dist_af_dev, &
30-
dist_sa_dev, dist_sc_dev
31-
32-
integer :: n, n_block, i, j, k, n_halo, n_stencil, n_iters
24+
integer :: n, n_block, i, j, k, n_halo, n_iters
3325
integer :: n_glob
3426
integer :: nrank, nproc, pprev, pnext, tag1=1234, tag2=1234
35-
integer, allocatable :: srerr(:), mpireq(:)
27+
integer :: srerr(4), mpireq(4)
3628
integer :: ierr, ndevs, devnum, memClockRt, memBusWidth
3729

3830
type(dim3) :: blocks, threads
39-
real(dp) :: dx, dx2, alfa, norm_du, tol = 1d-8, tstart, tend
40-
real(dp) :: achievedBW, deviceBW
31+
real(dp) :: dx, dx_per, norm_du, tol = 1d-8, tstart, tend
32+
real(dp) :: achievedBW, deviceBW, achievedBWmax, achievedBWmin
4133

4234
call MPI_Init(ierr)
4335
call MPI_Comm_rank(MPI_COMM_WORLD, nrank, ierr)
@@ -48,50 +40,34 @@ program test_cuda_tridiag
4840
ierr = cudaGetDeviceCount(ndevs)
4941
ierr = cudaSetDevice(mod(nrank, ndevs)) ! round-robin
5042
ierr = cudaGetDevice(devnum)
51-
print*, 'I am rank', nrank, 'I am running on device', devnum
52-
pnext = modulo(nrank-nproc+1, nproc)
53-
pprev = modulo(nrank-1, nproc)
54-
print*, 'rank', nrank, 'pnext', pnext, 'pprev', pprev
55-
allocate(srerr(nproc), mpireq(nproc))
43+
44+
!print*, 'I am rank', nrank, 'I am running on device', devnum
45+
pnext = modulo(nrank - nproc + 1, nproc)
46+
pprev = modulo(nrank - 1, nproc)
5647

5748
n_glob = 512*4
5849
n = n_glob/nproc
5950
n_block = 512*512/SZ
60-
n_iters = 1000
51+
n_iters = 100
6152

6253
allocate(u(SZ, n, n_block), du(SZ, n, n_block))
6354
allocate(u_dev(SZ, n, n_block), du_dev(SZ, n, n_block))
6455

65-
dx = 2*pi/n_glob
66-
dx2 = dx*dx
56+
dx_per = 2*pi/n_glob
57+
dx = 2*pi/(n_glob - 1)
6758

6859
do k = 1, n_block
6960
do j = 1, n
7061
do i = 1, SZ
71-
u(i, j, k) = sin((j-1+nrank*n)*dx)
62+
u(i, j, k) = sin((j - 1 + nrank*n)*dx_per)
7263
end do
7364
end do
7465
end do
7566

7667
! move data to device
7768
u_dev = u
7869

79-
! set up the tridiagonal solver coeffs
80-
call der_2_vv(coeffs, coeffs_s, coeffs_e, dist_fr, dist_bc, dist_af, &
81-
dist_sa, dist_sc, n_halo, dx2, n, 'periodic')
82-
83-
n_stencil = n_halo*2 + 1
84-
85-
allocate(coeffs_s_dev(n_stencil, n_halo), coeffs_e_dev(n_stencil, n_halo))
86-
allocate(coeffs_dev(n_stencil))
87-
coeffs_s_dev(:, :) = coeffs_s(:, :); coeffs_e_dev(:, :) = coeffs_e(:, :)
88-
coeffs_dev(:) = coeffs(:)
89-
90-
allocate(dist_fr_dev(n), dist_bc_dev(n), dist_af_dev(n), &
91-
dist_sa_dev(n), dist_sc_dev(n))
92-
dist_fr_dev(:) = dist_fr(:); dist_bc_dev(:) = dist_bc(:)
93-
dist_af_dev(:) = dist_af(:)
94-
dist_sa_dev(:) = dist_sa(:); dist_sc_dev(:) = dist_sc(:)
70+
n_halo = 4
9571

9672
! arrays for exchanging data between ranks
9773
allocate(u_send_s_dev(SZ, n_halo, n_block))
@@ -102,13 +78,17 @@ program test_cuda_tridiag
10278
allocate(send_s_dev(SZ, 1, n_block), send_e_dev(SZ, 1, n_block))
10379
allocate(recv_s_dev(SZ, 1, n_block), recv_e_dev(SZ, 1, n_block))
10480

81+
! preprocess the operator and coefficient arrays
82+
tdsops = cuda_tdsops_init(n, dx_per, operation='second-deriv', &
83+
scheme='compact6')
84+
10585
blocks = dim3(n_block, 1, 1)
10686
threads = dim3(SZ, 1, 1)
10787

10888
call cpu_time(tstart)
10989
do i = 1, n_iters
110-
u_send_s_dev(:,:,:) = u_dev(:,1:4,:)
111-
u_send_e_dev(:,:,:) = u_dev(:,n-n_halo+1:n,:)
90+
u_send_s_dev(:, :, :) = u_dev(:, 1:4, :)
91+
u_send_e_dev(:, :, :) = u_dev(:, n - n_halo + 1:n, :)
11292

11393
! halo exchange
11494
if (nproc == 1) then
@@ -135,8 +115,8 @@ program test_cuda_tridiag
135115

136116
call der_univ_dist<<<blocks, threads>>>( &
137117
du_dev, send_s_dev, send_e_dev, u_dev, u_recv_s_dev, u_recv_e_dev, &
138-
coeffs_s_dev, coeffs_e_dev, coeffs_dev, &
139-
n, dist_fr_dev, dist_bc_dev, dist_af_dev &
118+
tdsops%coeffs_s_dev, tdsops%coeffs_e_dev, tdsops%coeffs_dev, &
119+
n, tdsops%dist_fw_dev, tdsops%dist_bw_dev, tdsops%dist_af_dev &
140120
)
141121

142122
! halo exchange for 2x2 systems
@@ -163,36 +143,58 @@ program test_cuda_tridiag
163143

164144
call der_univ_subs<<<blocks, threads>>>( &
165145
du_dev, recv_s_dev, recv_e_dev, &
166-
n, dist_sa_dev, dist_sc_dev &
146+
n, tdsops%dist_sa_dev, tdsops%dist_sc_dev &
167147
)
168148
end do
169-
call cpu_time(tend)
170-
print*, 'Total time', tend-tstart
171149

172-
achievedBW = 6._dp*n_iters*n*n_block*SZ*dp/(tend-tstart)
173-
print'(a, f8.3, a)', 'Achieved BW: ', achievedBW/2**30, ' GiB/s'
150+
call cpu_time(tend)
151+
if (nrank == 0) print*, 'Total time', tend - tstart
152+
153+
! BW utilisation and performance checks
154+
! array-sized memory accesses counted per iteration: 4 in the first phase, 2 in the second phase, 6 in total
155+
achievedBW = 6._dp*n_iters*n*n_block*SZ*dp/(tend - tstart)
156+
call MPI_Allreduce(achievedBW, achievedBWmax, 1, MPI_DOUBLE_PRECISION, &
157+
MPI_MAX, MPI_COMM_WORLD, ierr)
158+
call MPI_Allreduce(achievedBW, achievedBWmin, 1, MPI_DOUBLE_PRECISION, &
159+
MPI_MIN, MPI_COMM_WORLD, ierr)
160+
161+
if (nrank == 0) then
162+
print'(a, f8.3, a)', 'Achieved BW min: ', achievedBWmin/2**30, ' GiB/s'
163+
print'(a, f8.3, a)', 'Achieved BW max: ', achievedBWmax/2**30, ' GiB/s'
164+
end if
174165

175166
ierr = cudaDeviceGetAttribute(memClockRt, cudaDevAttrMemoryClockRate, 0)
176-
ierr = cudaDeviceGetAttribute(memBusWidth, cudaDevAttrGlobalMemoryBusWidth, 0)
167+
ierr = cudaDeviceGetAttribute(memBusWidth, &
168+
cudaDevAttrGlobalMemoryBusWidth, 0)
177169
deviceBW = 2*memBusWidth/8._dp*memClockRt*1000
178170

179-
print'(a, f8.3, a)', 'Device BW: ', deviceBW/2**30, ' GiB/s'
180-
print'(a, f5.2)', 'Effective BW utilization: %', achievedBW/deviceBW*100
171+
if (nrank == 0) then
172+
print'(a, f8.3, a)', 'Device BW: ', deviceBW/2**30, ' GiB/s'
173+
print'(a, f5.2)', 'Effective BW util min: %', achievedBWmin/deviceBW*100
174+
print'(a, f5.2)', 'Effective BW util max: %', achievedBWmax/deviceBW*100
175+
end if
181176

182177
! check error
183178
du = du_dev
184-
norm_du = norm2(u+du)/n_block
185-
print*, 'error norm', norm_du
186-
187-
if ( norm_du > tol ) then
188-
allpass = .false.
189-
write(stderr, '(a)') 'Check second derivatives... failed'
190-
else
191-
write(stderr, '(a)') 'Check second derivatives... passed'
179+
norm_du = norm2(u + du)
180+
norm_du = norm_du*norm_du/n_glob/n_block/SZ
181+
call MPI_Allreduce(MPI_IN_PLACE, norm_du, 1, MPI_DOUBLE_PRECISION, &
182+
MPI_SUM, MPI_COMM_WORLD, ierr)
183+
norm_du = sqrt(norm_du)
184+
185+
if (nrank == 0) print*, 'error norm', norm_du
186+
187+
if (nrank == 0) then
188+
if ( norm_du > tol ) then
189+
allpass = .false.
190+
write(stderr, '(a)') 'Check second derivatives... failed'
191+
else
192+
write(stderr, '(a)') 'Check second derivatives... passed'
193+
end if
192194
end if
193195

194196
if (allpass) then
195-
write(stderr, '(a)') 'ALL TESTS PASSED SUCCESSFULLY.'
197+
if (nrank == 0) write(stderr, '(a)') 'ALL TESTS PASSED SUCCESSFULLY.'
196198
else
197199
error stop 'SOME TESTS FAILED.'
198200
end if

0 commit comments

Comments
 (0)