@@ -217,7 +217,7 @@ attributes(global) subroutine transeq_3fused_dist( &
217217 real (dp), device, intent (in ) :: d2_fw(:), d2_bw(:), d2_af(:)
218218
219219 ! Local variables
220- integer :: i, j, b, k
220+ integer :: i, j, b
221221
222222 real (dp) :: d1_c_m4, d1_c_m3, d1_c_m2, d1_c_m1, d1_c_j, &
223223 d1_c_p1, d1_c_p2, d1_c_p3, d1_c_p4, &
@@ -226,23 +226,14 @@ attributes(global) subroutine transeq_3fused_dist( &
226226 d2_c_p1, d2_c_p2, d2_c_p3, d2_c_p4, &
227227 d2_alpha, d2_last_r
228228 real (dp) :: temp_du, temp_dud, temp_d2u
229+ real (dp) :: u_m4, u_m3, u_m2, u_m1, u_j, u_p1, u_p2, u_p3, u_p4
230+ real (dp) :: v_m4, v_m3, v_m2, v_m1, v_j, v_p1, v_p2, v_p3, v_p4
231+ real (dp) :: old_du, old_dud, old_d2u
229232
230233 i = threadIdx% x
231234 b = blockIdx% x
232235
233- ! store bulk coeffs in the registers
234- d1_c_m4 = d1_coeffs(1 ); d1_c_m3 = d1_coeffs(2 )
235- d1_c_m2 = d1_coeffs(3 ); d1_c_m1 = d1_coeffs(4 )
236- d1_c_j = d1_coeffs(5 )
237- d1_c_p1 = d1_coeffs(6 ); d1_c_p2 = d1_coeffs(7 )
238- d1_c_p3 = d1_coeffs(8 ); d1_c_p4 = d1_coeffs(9 )
239236 d1_last_r = d1_fw(1 )
240-
241- d2_c_m4 = d2_coeffs(1 ); d2_c_m3 = d2_coeffs(2 )
242- d2_c_m2 = d2_coeffs(3 ); d2_c_m1 = d2_coeffs(4 )
243- d2_c_j = d2_coeffs(5 )
244- d2_c_p1 = d2_coeffs(6 ); d2_c_p2 = d2_coeffs(7 )
245- d2_c_p3 = d2_coeffs(8 ); d2_c_p4 = d2_coeffs(9 )
246237 d2_last_r = d2_fw(1 )
247238
248239 ! j = 1
@@ -373,40 +364,65 @@ attributes(global) subroutine transeq_3fused_dist( &
373364 d1_alpha = d1_af(5 )
374365 d2_alpha = d2_af(5 )
375366
367+ ! store bulk coeffs in the registers
368+ d1_c_m4 = d1_coeffs(1 ); d1_c_m3 = d1_coeffs(2 )
369+ d1_c_m2 = d1_coeffs(3 ); d1_c_m1 = d1_coeffs(4 )
370+ d1_c_j = d1_coeffs(5 )
371+ d1_c_p1 = d1_coeffs(6 ); d1_c_p2 = d1_coeffs(7 )
372+ d1_c_p3 = d1_coeffs(8 ); d1_c_p4 = d1_coeffs(9 )
373+
374+ d2_c_m4 = d2_coeffs(1 ); d2_c_m3 = d2_coeffs(2 )
375+ d2_c_m2 = d2_coeffs(3 ); d2_c_m1 = d2_coeffs(4 )
376+ d2_c_j = d2_coeffs(5 )
377+ d2_c_p1 = d2_coeffs(6 ); d2_c_p2 = d2_coeffs(7 )
378+ d2_c_p3 = d2_coeffs(8 ); d2_c_p4 = d2_coeffs(9 )
379+
380+ ! It is better to access d?(i, j - 1, b) via old_d?
381+ old_du = du(i, 4 , b)
382+ old_dud = dud(i, 4 , b)
383+ old_d2u = d2u(i, 4 , b)
384+
385+ ! Populate registers with the u and v stencils
386+ u_m4 = u(i, 1 , b); u_m3 = u(i, 2 , b)
387+ u_m2 = u(i, 3 , b); u_m1 = u(i, 4 , b)
388+ u_j = u(i, 5 , b); u_p1 = u(i, 6 , b)
389+ u_p2 = u(i, 7 , b); u_p3 = u(i, 8 , b)
390+ v_m4 = v(i, 1 , b); v_m3 = v(i, 2 , b)
391+ v_m2 = v(i, 3 , b); v_m1 = v(i, 4 , b)
392+ v_j = v(i, 5 , b); v_p1 = v(i, 6 , b)
393+ v_p2 = v(i, 7 , b); v_p3 = v(i, 8 , b)
394+
376395 do j = 5 , n - 4
396+ u_p4 = u(i, j+4 , b); v_p4 = v(i, j+4 , b)
397+
377398 ! du
378- temp_du = d1_c_m4* u(i, j - 4 , b) &
379- + d1_c_m3* u(i, j - 3 , b) &
380- + d1_c_m2* u(i, j - 2 , b) &
381- + d1_c_m1* u(i, j - 1 , b) &
382- + d1_c_j* u(i, j, b) &
383- + d1_c_p1* u(i, j + 1 , b) &
384- + d1_c_p2* u(i, j + 2 , b) &
385- + d1_c_p3* u(i, j + 3 , b) &
386- + d1_c_p4* u(i, j + 4 , b)
387- du(i, j, b) = d1_fw(j)* (temp_du - d1_alpha* du(i, j - 1 , b))
399+ temp_du = d1_c_m4* u_m4 + d1_c_m3* u_m3 + d1_c_m2* u_m2 + d1_c_m1* u_m1 &
400+ + d1_c_j* u_j &
401+ + d1_c_p1* u_p1 + d1_c_p2* u_p2 + d1_c_p3* u_p3 + d1_c_p4* u_p4
402+ du(i, j, b) = d1_fw(j)* (temp_du - d1_alpha* old_du)
403+ old_du = du(i, j, b)
404+
388405 ! dud
389- temp_dud = d1_c_m4* u(i, j - 4 , b)* v(i, j - 4 , b) &
390- + d1_c_m3* u(i, j - 3 , b)* v(i, j - 3 , b) &
391- + d1_c_m2* u(i, j - 2 , b)* v(i, j - 2 , b) &
392- + d1_c_m1* u(i, j - 1 , b)* v(i, j - 1 , b) &
393- + d1_c_j* u(i, j, b)* v(i, j, b) &
394- + d1_c_p1* u(i, j + 1 , b)* v(i, j + 1 , b) &
395- + d1_c_p2* u(i, j + 2 , b)* v(i, j + 2 , b) &
396- + d1_c_p3* u(i, j + 3 , b)* v(i, j + 3 , b) &
397- + d1_c_p4* u(i, j + 4 , b)* v(i, j + 4 , b)
398- dud(i, j, b) = d1_fw(j)* (temp_dud - d1_alpha* dud(i, j - 1 , b))
406+ temp_dud = d1_c_m4* u_m4* v_m4 + d1_c_m3* u_m3* v_m3 &
407+ + d1_c_m2* u_m2* v_m2 + d1_c_m1* u_m1* v_m1 &
408+ + d1_c_j* u_j* v_j &
409+ + d1_c_p1* u_p1* v_p1 + d1_c_p2* u_p2* v_p2 &
410+ + d1_c_p3* u_p3* v_p3 + d1_c_p4* u_p4* v_p4
411+ dud(i, j, b) = d1_fw(j)* (temp_dud - d1_alpha* old_dud)
412+ old_dud = dud(i, j, b)
413+
399414 ! d2u
400- temp_d2u = d2_c_m4* u(i, j - 4 , b) &
401- + d2_c_m3* u(i, j - 3 , b) &
402- + d2_c_m2* u(i, j - 2 , b) &
403- + d2_c_m1* u(i, j - 1 , b) &
404- + d2_c_j* u(i, j, b) &
405- + d2_c_p1* u(i, j + 1 , b) &
406- + d2_c_p2* u(i, j + 2 , b) &
407- + d2_c_p3* u(i, j + 3 , b) &
408- + d2_c_p4* u(i, j + 4 , b)
409- d2u(i, j, b) = d2_fw(j)* (temp_d2u - d2_alpha* d2u(i, j - 1 , b))
415+ temp_d2u = d2_c_m4* u_m4 + d2_c_m3* u_m3 + d2_c_m2* u_m2 + d2_c_m1* u_m1 &
416+ + d2_c_j* u_j &
417+ + d2_c_p1* u_p1 + d2_c_p2* u_p2 + d2_c_p3* u_p3 + d2_c_p4* u_p4
418+ d2u(i, j, b) = d2_fw(j)* (temp_d2u - d2_alpha* old_d2u)
419+ old_d2u = d2u(i, j, b)
420+
421+ ! Prepare registers for the next step
422+ u_m4 = u_m3; u_m3 = u_m2; u_m2 = u_m1; u_m1 = u_j
423+ u_j = u_p1; u_p1 = u_p2; u_p2 = u_p3; u_p3 = u_p4
424+ v_m4 = v_m3; v_m3 = v_m2; v_m2 = v_m1; v_m1 = v_j
425+ v_j = v_p1; v_p1 = v_p2; v_p2 = v_p3; v_p3 = v_p4
410426 end do
411427
412428 j = n - 3
0 commit comments