Skip to content

Commit 02a104d

Browse files
committed
perf(cuda): Improve performance of the fused dist transeq kernel.
1 parent 3769b96 commit 02a104d

File tree

1 file changed

+59
-43
lines changed

1 file changed

+59
-43
lines changed

src/cuda/kernels_dist.f90

Lines changed: 59 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ attributes(global) subroutine transeq_3fused_dist( &
217217
real(dp), device, intent(in) :: d2_fw(:), d2_bw(:), d2_af(:)
218218

219219
! Local variables
220-
integer :: i, j, b, k
220+
integer :: i, j, b
221221

222222
real(dp) :: d1_c_m4, d1_c_m3, d1_c_m2, d1_c_m1, d1_c_j, &
223223
d1_c_p1, d1_c_p2, d1_c_p3, d1_c_p4, &
@@ -226,23 +226,14 @@ attributes(global) subroutine transeq_3fused_dist( &
226226
d2_c_p1, d2_c_p2, d2_c_p3, d2_c_p4, &
227227
d2_alpha, d2_last_r
228228
real(dp) :: temp_du, temp_dud, temp_d2u
229+
real(dp) :: u_m4, u_m3, u_m2, u_m1, u_j, u_p1, u_p2, u_p3, u_p4
230+
real(dp) :: v_m4, v_m3, v_m2, v_m1, v_j, v_p1, v_p2, v_p3, v_p4
231+
real(dp) :: old_du, old_dud, old_d2u
229232

230233
i = threadIdx%x
231234
b = blockIdx%x
232235

233-
! store bulk coeffs in the registers
234-
d1_c_m4 = d1_coeffs(1); d1_c_m3 = d1_coeffs(2)
235-
d1_c_m2 = d1_coeffs(3); d1_c_m1 = d1_coeffs(4)
236-
d1_c_j = d1_coeffs(5)
237-
d1_c_p1 = d1_coeffs(6); d1_c_p2 = d1_coeffs(7)
238-
d1_c_p3 = d1_coeffs(8); d1_c_p4 = d1_coeffs(9)
239236
d1_last_r = d1_fw(1)
240-
241-
d2_c_m4 = d2_coeffs(1); d2_c_m3 = d2_coeffs(2)
242-
d2_c_m2 = d2_coeffs(3); d2_c_m1 = d2_coeffs(4)
243-
d2_c_j = d2_coeffs(5)
244-
d2_c_p1 = d2_coeffs(6); d2_c_p2 = d2_coeffs(7)
245-
d2_c_p3 = d2_coeffs(8); d2_c_p4 = d2_coeffs(9)
246237
d2_last_r = d2_fw(1)
247238

248239
! j = 1
@@ -373,40 +364,65 @@ attributes(global) subroutine transeq_3fused_dist( &
373364
d1_alpha = d1_af(5)
374365
d2_alpha = d2_af(5)
375366

367+
! store bulk coeffs in the registers
368+
d1_c_m4 = d1_coeffs(1); d1_c_m3 = d1_coeffs(2)
369+
d1_c_m2 = d1_coeffs(3); d1_c_m1 = d1_coeffs(4)
370+
d1_c_j = d1_coeffs(5)
371+
d1_c_p1 = d1_coeffs(6); d1_c_p2 = d1_coeffs(7)
372+
d1_c_p3 = d1_coeffs(8); d1_c_p4 = d1_coeffs(9)
373+
374+
d2_c_m4 = d2_coeffs(1); d2_c_m3 = d2_coeffs(2)
375+
d2_c_m2 = d2_coeffs(3); d2_c_m1 = d2_coeffs(4)
376+
d2_c_j = d2_coeffs(5)
377+
d2_c_p1 = d2_coeffs(6); d2_c_p2 = d2_coeffs(7)
378+
d2_c_p3 = d2_coeffs(8); d2_c_p4 = d2_coeffs(9)
379+
380+
! It is better to access d?(i, j - 1, b) via old_d?
381+
old_du = du(i, 4, b)
382+
old_dud = dud(i, 4, b)
383+
old_d2u = d2u(i, 4, b)
384+
385+
! Populate registers with the u and v stencils
386+
u_m4 = u(i, 1, b); u_m3 = u(i, 2, b)
387+
u_m2 = u(i, 3, b); u_m1 = u(i, 4, b)
388+
u_j = u(i, 5, b); u_p1 = u(i, 6, b)
389+
u_p2 = u(i, 7, b); u_p3 = u(i, 8, b)
390+
v_m4 = v(i, 1, b); v_m3 = v(i, 2, b)
391+
v_m2 = v(i, 3, b); v_m1 = v(i, 4, b)
392+
v_j = v(i, 5, b); v_p1 = v(i, 6, b)
393+
v_p2 = v(i, 7, b); v_p3 = v(i, 8, b)
394+
376395
do j = 5, n - 4
396+
u_p4 = u(i, j+4, b); v_p4 = v(i, j+4, b)
397+
377398
! du
378-
temp_du = d1_c_m4*u(i, j - 4, b) &
379-
+ d1_c_m3*u(i, j - 3, b) &
380-
+ d1_c_m2*u(i, j - 2, b) &
381-
+ d1_c_m1*u(i, j - 1, b) &
382-
+ d1_c_j*u(i, j, b) &
383-
+ d1_c_p1*u(i, j + 1, b) &
384-
+ d1_c_p2*u(i, j + 2, b) &
385-
+ d1_c_p3*u(i, j + 3, b) &
386-
+ d1_c_p4*u(i, j + 4, b)
387-
du(i, j, b) = d1_fw(j)*(temp_du - d1_alpha*du(i, j - 1, b))
399+
temp_du = d1_c_m4*u_m4 + d1_c_m3*u_m3 + d1_c_m2*u_m2 + d1_c_m1*u_m1 &
400+
+ d1_c_j*u_j &
401+
+ d1_c_p1*u_p1 + d1_c_p2*u_p2 + d1_c_p3*u_p3 + d1_c_p4*u_p4
402+
du(i, j, b) = d1_fw(j)*(temp_du - d1_alpha*old_du)
403+
old_du = du(i, j, b)
404+
388405
! dud
389-
temp_dud = d1_c_m4*u(i, j - 4, b)*v(i, j - 4, b) &
390-
+ d1_c_m3*u(i, j - 3, b)*v(i, j - 3, b) &
391-
+ d1_c_m2*u(i, j - 2, b)*v(i, j - 2, b) &
392-
+ d1_c_m1*u(i, j - 1, b)*v(i, j - 1, b) &
393-
+ d1_c_j*u(i, j, b)*v(i, j, b) &
394-
+ d1_c_p1*u(i, j + 1, b)*v(i, j + 1, b) &
395-
+ d1_c_p2*u(i, j + 2, b)*v(i, j + 2, b) &
396-
+ d1_c_p3*u(i, j + 3, b)*v(i, j + 3, b) &
397-
+ d1_c_p4*u(i, j + 4, b)*v(i, j + 4, b)
398-
dud(i, j, b) = d1_fw(j)*(temp_dud - d1_alpha*dud(i, j - 1, b))
406+
temp_dud = d1_c_m4*u_m4*v_m4 + d1_c_m3*u_m3*v_m3 &
407+
+ d1_c_m2*u_m2*v_m2 + d1_c_m1*u_m1*v_m1 &
408+
+ d1_c_j*u_j*v_j &
409+
+ d1_c_p1*u_p1*v_p1 + d1_c_p2*u_p2*v_p2 &
410+
+ d1_c_p3*u_p3*v_p3 + d1_c_p4*u_p4*v_p4
411+
dud(i, j, b) = d1_fw(j)*(temp_dud - d1_alpha*old_dud)
412+
old_dud = dud(i, j, b)
413+
399414
! d2u
400-
temp_d2u = d2_c_m4*u(i, j - 4, b) &
401-
+ d2_c_m3*u(i, j - 3, b) &
402-
+ d2_c_m2*u(i, j - 2, b) &
403-
+ d2_c_m1*u(i, j - 1, b) &
404-
+ d2_c_j*u(i, j, b) &
405-
+ d2_c_p1*u(i, j + 1, b) &
406-
+ d2_c_p2*u(i, j + 2, b) &
407-
+ d2_c_p3*u(i, j + 3, b) &
408-
+ d2_c_p4*u(i, j + 4, b)
409-
d2u(i, j, b) = d2_fw(j)*(temp_d2u - d2_alpha*d2u(i, j - 1, b))
415+
temp_d2u = d2_c_m4*u_m4 + d2_c_m3*u_m3 + d2_c_m2*u_m2 + d2_c_m1*u_m1 &
416+
+ d2_c_j*u_j &
417+
+ d2_c_p1*u_p1 + d2_c_p2*u_p2 + d2_c_p3*u_p3 + d2_c_p4*u_p4
418+
d2u(i, j, b) = d2_fw(j)*(temp_d2u - d2_alpha*old_d2u)
419+
old_d2u = d2u(i, j, b)
420+
421+
! Prepare registers for the next step
422+
u_m4 = u_m3; u_m3 = u_m2; u_m2 = u_m1; u_m1 = u_j
423+
u_j = u_p1; u_p1 = u_p2; u_p2 = u_p3; u_p3 = u_p4
424+
v_m4 = v_m3; v_m3 = v_m2; v_m2 = v_m1; v_m1 = v_j
425+
v_j = v_p1; v_p1 = v_p2; v_p2 = v_p3; v_p3 = v_p4
410426
end do
411427

412428
j = n - 3

0 commit comments

Comments
 (0)