@@ -6,38 +6,30 @@ program test_cuda_tridiag
66 use m_common, only: dp, pi
77 use m_cuda_common, only: SZ
88 use m_cuda_kernels_dist, only: der_univ_dist, der_univ_subs
9- use m_derparams , only: der_1_vv, der_2_vv
9+ use m_cuda_tdsops , only: cuda_tdsops_t, cuda_tdsops_init
1010
1111 implicit none
1212
1313 logical :: allpass = .true.
14- real (dp), allocatable , dimension (:,:, :) :: u, du, u_s, u_e
15- real (dp), device, allocatable , dimension (:,:, :) :: u_dev, du_dev
16- real (dp), device, allocatable , dimension (:,:, :) :: u_recv_s_dev, u_recv_e_dev, &
17- u_send_s_dev, u_send_e_dev
14+ real (dp), allocatable , dimension (:, :, :) :: u, du
15+ real (dp), device, allocatable , dimension (:, :, :) :: u_dev, du_dev
16+ real (dp), device, allocatable , dimension (:, :, :) :: &
17+ u_recv_s_dev, u_recv_e_dev, u_send_s_dev, u_send_e_dev
1818
19- real (dp), device, allocatable , dimension (:,:, :) :: send_s_dev, send_e_dev, &
20- recv_s_dev, recv_e_dev
19+ real (dp), device, allocatable , dimension (:, :, :) :: &
20+ send_s_dev, send_e_dev, recv_s_dev, recv_e_dev
2121
22- real (dp), allocatable , dimension (:,:) :: coeffs_s, coeffs_e
23- real (dp), allocatable , dimension (:) :: coeffs, dist_fr, dist_bc, dist_af, &
24- dist_sa, dist_sc
22+ type (cuda_tdsops_t) :: tdsops
2523
26- real (dp), device, allocatable , dimension (:,:) :: coeffs_s_dev, coeffs_e_dev
27- real (dp), device, allocatable , dimension (:) :: coeffs_dev, &
28- dist_fr_dev, dist_bc_dev, &
29- dist_af_dev, &
30- dist_sa_dev, dist_sc_dev
31-
32- integer :: n, n_block, i, j, k, n_halo, n_stencil, n_iters
24+ integer :: n, n_block, i, j, k, n_halo, n_iters
3325 integer :: n_glob
3426 integer :: nrank, nproc, pprev, pnext, tag1= 1234 , tag2= 1234
35- integer , allocatable :: srerr(: ), mpireq(: )
27+ integer :: srerr(4 ), mpireq(4 )
3628 integer :: ierr, ndevs, devnum, memClockRt, memBusWidth
3729
3830 type (dim3) :: blocks, threads
39- real (dp) :: dx, dx2, alfa , norm_du, tol = 1d-8 , tstart, tend
40- real (dp) :: achievedBW, deviceBW
31+ real (dp) :: dx, dx_per , norm_du, tol = 1d-8 , tstart, tend
32+ real (dp) :: achievedBW, deviceBW, achievedBWmax, achievedBWmin
4133
4234 call MPI_Init(ierr)
4335 call MPI_Comm_rank(MPI_COMM_WORLD, nrank, ierr)
@@ -48,50 +40,34 @@ program test_cuda_tridiag
4840 ierr = cudaGetDeviceCount(ndevs)
4941 ierr = cudaSetDevice(mod (nrank, ndevs)) ! round-robin
5042 ierr = cudaGetDevice(devnum)
51- print * , ' I am rank' , nrank, ' I am running on device' , devnum
52- pnext = modulo (nrank- nproc+1 , nproc)
53- pprev = modulo (nrank-1 , nproc)
54- print * , ' rank' , nrank, ' pnext' , pnext, ' pprev' , pprev
55- allocate (srerr(nproc), mpireq(nproc))
43+
44+ ! print*, 'I am rank', nrank, 'I am running on device', devnum
45+ pnext = modulo (nrank - nproc + 1 , nproc)
46+ pprev = modulo (nrank - 1 , nproc)
5647
5748 n_glob = 512 * 4
5849 n = n_glob/ nproc
5950 n_block = 512 * 512 / SZ
60- n_iters = 1000
51+ n_iters = 100
6152
6253 allocate (u(SZ, n, n_block), du(SZ, n, n_block))
6354 allocate (u_dev(SZ, n, n_block), du_dev(SZ, n, n_block))
6455
65- dx = 2 * pi/ n_glob
66- dx2 = dx * dx
56+ dx_per = 2 * pi/ n_glob
57+ dx = 2 * pi / (n_glob - 1 )
6758
6859 do k = 1 , n_block
6960 do j = 1 , n
7061 do i = 1 , SZ
71- u(i, j, k) = sin ((j-1 + nrank* n)* dx )
62+ u(i, j, k) = sin ((j - 1 + nrank* n)* dx_per )
7263 end do
7364 end do
7465 end do
7566
7667 ! move data to device
7768 u_dev = u
7869
79- ! set up the tridiagonal solver coeffs
80- call der_2_vv(coeffs, coeffs_s, coeffs_e, dist_fr, dist_bc, dist_af, &
81- dist_sa, dist_sc, n_halo, dx2, n, ' periodic' )
82-
83- n_stencil = n_halo* 2 + 1
84-
85- allocate (coeffs_s_dev(n_stencil, n_halo), coeffs_e_dev(n_stencil, n_halo))
86- allocate (coeffs_dev(n_stencil))
87- coeffs_s_dev(:, :) = coeffs_s(:, :); coeffs_e_dev(:, :) = coeffs_e(:, :)
88- coeffs_dev(:) = coeffs(:)
89-
90- allocate (dist_fr_dev(n), dist_bc_dev(n), dist_af_dev(n), &
91- dist_sa_dev(n), dist_sc_dev(n))
92- dist_fr_dev(:) = dist_fr(:); dist_bc_dev(:) = dist_bc(:)
93- dist_af_dev(:) = dist_af(:)
94- dist_sa_dev(:) = dist_sa(:); dist_sc_dev(:) = dist_sc(:)
70+ n_halo = 4
9571
9672 ! arrays for exchanging data between ranks
9773 allocate (u_send_s_dev(SZ, n_halo, n_block))
@@ -102,13 +78,17 @@ program test_cuda_tridiag
10278 allocate (send_s_dev(SZ, 1 , n_block), send_e_dev(SZ, 1 , n_block))
10379 allocate (recv_s_dev(SZ, 1 , n_block), recv_e_dev(SZ, 1 , n_block))
10480
81+ ! preprocess the operator and coefficient arrays
82+ tdsops = cuda_tdsops_init(n, dx_per, operation= ' second-deriv' , &
83+ scheme= ' compact6' )
84+
10585 blocks = dim3(n_block, 1 , 1 )
10686 threads = dim3(SZ, 1 , 1 )
10787
10888 call cpu_time(tstart)
10989 do i = 1 , n_iters
110- u_send_s_dev(:,:, :) = u_dev(:,1 :4 ,:)
111- u_send_e_dev(:,:, :) = u_dev(:,n - n_halo+ 1 :n,:)
90+ u_send_s_dev(:, :, :) = u_dev(:, 1 :4 , :)
91+ u_send_e_dev(:, :, :) = u_dev(:, n - n_halo + 1 :n, :)
11292
11393 ! halo exchange
11494 if (nproc == 1 ) then
@@ -135,8 +115,8 @@ program test_cuda_tridiag
135115
136116 call der_univ_dist<<<blocks, threads>>>( &
137117 du_dev, send_s_dev, send_e_dev, u_dev, u_recv_s_dev, u_recv_e_dev, &
138- coeffs_s_dev, coeffs_e_dev, coeffs_dev, &
139- n, dist_fr_dev, dist_bc_dev, dist_af_dev &
118+ tdsops % coeffs_s_dev, tdsops % coeffs_e_dev, tdsops % coeffs_dev, &
119+ n, tdsops % dist_fw_dev, tdsops % dist_bw_dev, tdsops % dist_af_dev &
140120 )
141121
142122 ! halo exchange for 2x2 systems
@@ -163,36 +143,58 @@ program test_cuda_tridiag
163143
164144 call der_univ_subs<<<blocks, threads>>>( &
165145 du_dev, recv_s_dev, recv_e_dev, &
166- n, dist_sa_dev, dist_sc_dev &
146+ n, tdsops % dist_sa_dev, tdsops % dist_sc_dev &
167147 )
168148 end do
169- call cpu_time(tend)
170- print * , ' Total time' , tend- tstart
171149
172- achievedBW = 6._dp * n_iters* n* n_block* SZ* dp/ (tend- tstart)
173- print ' (a, f8.3, a)' , ' Achieved BW: ' , achievedBW/ 2 ** 30 , ' GiB/s'
150+ call cpu_time(tend)
151+ if (nrank == 0 ) print * , ' Total time' , tend - tstart
152+
153+ ! BW utilisation and performance checks
154+ ! array-sized memory accesses per iteration: 4 in the first phase, 2 in the second phase, 6 in total
155+ achievedBW = 6._dp * n_iters* n* n_block* SZ* dp/ (tend - tstart)
156+ call MPI_Allreduce(achievedBW, achievedBWmax, 1 , MPI_DOUBLE_PRECISION, &
157+ MPI_MAX, MPI_COMM_WORLD, ierr)
158+ call MPI_Allreduce(achievedBW, achievedBWmin, 1 , MPI_DOUBLE_PRECISION, &
159+ MPI_MIN, MPI_COMM_WORLD, ierr)
160+
161+ if (nrank == 0 ) then
162+ print ' (a, f8.3, a)' , ' Achieved BW min: ' , achievedBWmin/ 2 ** 30 , ' GiB/s'
163+ print ' (a, f8.3, a)' , ' Achieved BW max: ' , achievedBWmax/ 2 ** 30 , ' GiB/s'
164+ end if
174165
175166 ierr = cudaDeviceGetAttribute(memClockRt, cudaDevAttrMemoryClockRate, 0 )
176- ierr = cudaDeviceGetAttribute(memBusWidth, cudaDevAttrGlobalMemoryBusWidth, 0 )
167+ ierr = cudaDeviceGetAttribute(memBusWidth, &
168+ cudaDevAttrGlobalMemoryBusWidth, 0 )
177169 deviceBW = 2 * memBusWidth/ 8._dp * memClockRt* 1000
178170
179- print ' (a, f8.3, a)' , ' Device BW: ' , deviceBW/ 2 ** 30 , ' GiB/s'
180- print ' (a, f5.2)' , ' Effective BW utilization: %' , achievedBW/ deviceBW* 100
171+ if (nrank == 0 ) then
172+ print ' (a, f8.3, a)' , ' Device BW: ' , deviceBW/ 2 ** 30 , ' GiB/s'
173+ print ' (a, f5.2)' , ' Effective BW util min: %' , achievedBWmin/ deviceBW* 100
174+ print ' (a, f5.2)' , ' Effective BW util max: %' , achievedBWmax/ deviceBW* 100
175+ end if
181176
182177 ! check error: u is a sine, so its second derivative du should equal -u and u + du should vanish
183178 du = du_dev
184- norm_du = norm2(u+ du)/ n_block
185- print * , ' error norm' , norm_du
186-
187- if ( norm_du > tol ) then
188- allpass = .false.
189- write (stderr, ' (a)' ) ' Check second derivatives... failed'
190- else
191- write (stderr, ' (a)' ) ' Check second derivatives... passed'
179+ norm_du = norm2(u + du)
180+ norm_du = norm_du* norm_du/ n_glob/ n_block/ SZ
181+ call MPI_Allreduce(MPI_IN_PLACE, norm_du, 1 , MPI_DOUBLE_PRECISION, &
182+ MPI_SUM, MPI_COMM_WORLD, ierr)
183+ norm_du = sqrt (norm_du)
184+
185+ if (nrank == 0 ) print * , ' error norm' , norm_du
186+
187+ if (nrank == 0 ) then
188+ if ( norm_du > tol ) then
189+ allpass = .false.
190+ write (stderr, ' (a)' ) ' Check second derivatives... failed'
191+ else
192+ write (stderr, ' (a)' ) ' Check second derivatives... passed'
193+ end if
192194 end if
193195
194196 if (allpass) then
195- write (stderr, ' (a)' ) ' ALL TESTS PASSED SUCCESSFULLY.'
197+ if (nrank == 0 ) write (stderr, ' (a)' ) ' ALL TESTS PASSED SUCCESSFULLY.'
196198 else
197199 error stop ' SOME TESTS FAILED.'
198200 end if
0 commit comments