Skip to content

Commit 3a1caf4

Browse files
committed
fix: Allow the distributed algorithm to solve tridiagonal systems (TDSs) with non-uniform off-diagonals.
1 parent c8937de commit 3a1caf4

File tree

5 files changed

+120
-126
lines changed

5 files changed

+120
-126
lines changed

src/cuda/kernels_dist.f90

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,70 +9,69 @@ module m_cuda_kernels_dist
99

1010
attributes(global) subroutine der_univ_dist( &
1111
du, send_u_b, send_u_e, u, u_b, u_e, coeffs_b, coeffs_e, coeffs, n, &
12-
alfa, ffr, fbc &
12+
ffr, fbc, faf &
1313
)
1414
implicit none
1515

1616
! Arguments
1717
real(dp), device, intent(out), dimension(:, :, :) :: du, send_u_b, &
1818
send_u_e
1919
real(dp), device, intent(in), dimension(:, :, :) :: u, u_b, u_e
20-
real(dp), device, intent(in), dimension(:) :: ffr, fbc
2120
real(dp), device, intent(in), dimension(:, :) :: coeffs_b, coeffs_e
2221
real(dp), device, intent(in), dimension(:) :: coeffs
23-
real(dp), value, intent(in) :: alfa
2422
integer, value, intent(in) :: n
23+
real(dp), device, intent(in), dimension(:) :: ffr, fbc, faf
2524

2625
! Local variables
2726
integer :: i, j, b, k, lj
2827
integer :: jm2, jm1, jp1, jp2
29-
integer :: n_s, n_m, n_b, n_e !stencil, middle, begin, end
3028

31-
real(dp) :: temp_du, c_m4, c_m3, c_m2, c_m1, c_j, c_p1, c_p2, c_p3, c_p4
29+
real(dp) :: c_m4, c_m3, c_m2, c_m1, c_j, c_p1, c_p2, c_p3, c_p4, &
30+
temp_du, alpha, last_r
3231

3332
i = threadIdx%x
3433
b = blockIdx%x
3534

36-
n_s = (size(coeffs)-1)/2
37-
n_m = size(coeffs)
38-
n_b = size(coeffs_b, dim=2)
39-
n_e = size(coeffs_e, dim=2)
40-
4135
! store bulk coeffs in the registers
4236
c_m4 = coeffs(1); c_m3 = coeffs(2); c_m2 = coeffs(3); c_m1 = coeffs(4)
4337
c_j = coeffs(5)
4438
c_p1 = coeffs(6); c_p2 = coeffs(7); c_p3 = coeffs(8); c_p4 = coeffs(9)
39+
last_r = ffr(1)
4540

4641
du(i, 1, b) = coeffs(1)*u_b(i, 1, b) + coeffs(2)*u_b(i, 2, b) &
4742
+ coeffs(3)*u_b(i, 3, b) + coeffs(4)*u_b(i, 4, b) &
4843
+ coeffs(5)*u(i, 1, b) &
4944
+ coeffs(6)*u(i, 2, b) + coeffs(7)*u(i, 3, b) &
5045
+ coeffs(8)*u(i, 4, b) + coeffs(9)*u(i, 5, b)
46+
du(i, 1, b) = du(i, 1, b)*faf(1)
5147
du(i, 2, b) = coeffs(1)*u_b(i, 2, b) + coeffs(2)*u_b(i, 3, b) &
5248
+ coeffs(3)*u_b(i, 4, b) + coeffs(4)*u(i, 1, b) &
5349
+ coeffs(5)*u(i, 2, b) &
5450
+ coeffs(6)*u(i, 3, b) + coeffs(7)*u(i, 4, b) &
5551
+ coeffs(8)*u(i, 5, b) + coeffs(9)*u(i, 6, b)
52+
du(i, 2, b) = du(i, 2, b)*faf(2)
5653
du(i, 3, b) = coeffs(1)*u_b(i, 3, b) + coeffs(2)*u_b(i, 4, b) &
5754
+ coeffs(3)*u(i, 1, b) + coeffs(4)*u(i, 2, b) &
5855
+ coeffs(5)*u(i, 3, b) &
5956
+ coeffs(6)*u(i, 4, b) + coeffs(7)*u(i, 5, b) &
6057
+ coeffs(8)*u(i, 6, b) + coeffs(9)*u(i, 7, b)
61-
du(i, 3, b) = ffr(3)*(du(i, 3, b) - alfa*du(i, 2, b))
58+
du(i, 3, b) = ffr(3)*(du(i, 3, b) - faf(3)*du(i, 2, b))
6259
du(i, 4, b) = coeffs(1)*u_b(i, 4, b) + coeffs(2)*u(i, 1, b) &
6360
+ coeffs(3)*u(i, 2, b) + coeffs(4)*u(i, 3, b) &
6461
+ coeffs(5)*u(i, 4, b) &
6562
+ coeffs(6)*u(i, 5, b) + coeffs(7)*u(i, 6, b) &
6663
+ coeffs(8)*u(i, 7, b) + coeffs(9)*u(i, 8, b)
67-
du(i, 4, b) = ffr(4)*(du(i, 4, b) - alfa*du(i, 3, b))
64+
du(i, 4, b) = ffr(4)*(du(i, 4, b) - faf(3)*du(i, 3, b))
65+
66+
alpha = faf(5)
6867

69-
do j = n_s+1, n-n_s
68+
do j = 5, n-4
7069
temp_du = c_m4*u(i, j-4, b) + c_m3*u(i, j-3, b) &
7170
+ c_m2*u(i, j-2, b) + c_m1*u(i, j-1, b) &
7271
+ c_j*u(i, j, b) &
7372
+ c_p1*u(i, j+1, b) + c_p2*u(i, j+2, b) &
7473
+ c_p3*u(i, j+3, b) + c_p4*u(i, j+4, b)
75-
du(i, j, b) = ffr(j)*(temp_du - alfa*du(i, j-1, b))
74+
du(i, j, b) = ffr(j)*(temp_du - alpha*du(i, j-1, b))
7675
end do
7776

7877
j = n-3
@@ -81,49 +80,48 @@ attributes(global) subroutine der_univ_dist( &
8180
+ coeffs(5)*u(i, j, b) &
8281
+ coeffs(6)*u(i, j+1, b) + coeffs(7)*u(i, j+2, b) &
8382
+ coeffs(8)*u(i, j+3, b) + coeffs(9)*u_e(i, 1, b)
84-
du(i, j, b) = ffr(j)*(du(i, j, b) - alfa*du(i, j-1, b))
83+
du(i, j, b) = ffr(j)*(du(i, j, b) - faf(j)*du(i, j-1, b))
8584
j = n-2
8685
du(i, j, b) = coeffs(1)*u(i, j-4, b) + coeffs(2)*u(i, j-3, b) &
8786
+ coeffs(3)*u(i, j-2, b) + coeffs(4)*u(i, j-1, b) &
8887
+ coeffs(5)*u(i, j, b) &
8988
+ coeffs(6)*u(i, j+1, b) + coeffs(7)*u(i, j+2, b) &
9089
+ coeffs(8)*u_e(i, 1, b) + coeffs(9)*u_e(i, 2, b)
91-
du(i, j, b) = ffr(j)*(du(i, j, b) - alfa*du(i, j-1, b))
90+
du(i, j, b) = ffr(j)*(du(i, j, b) - faf(j)*du(i, j-1, b))
9291
j = n-1
9392
du(i, j, b) = coeffs(1)*u(i, j-4, b) + coeffs(2)*u(i, j-3, b) &
9493
+ coeffs(3)*u(i, j-2, b) + coeffs(4)*u(i, j-1, b) &
9594
+ coeffs(5)*u(i, j, b) &
9695
+ coeffs(6)*u(i, j+1, b) + coeffs(7)*u_e(i, 1, b) &
9796
+ coeffs(8)*u_e(i, 2, b) + coeffs(9)*u_e(i, 3, b)
98-
du(i, j, b) = ffr(j)*(du(i, j, b) - alfa*du(i, j-1, b))
97+
du(i, j, b) = ffr(j)*(du(i, j, b) - faf(j)*du(i, j-1, b))
9998
j = n
10099
du(i, j, b) = coeffs(1)*u(i, j-4, b) + coeffs(2)*u(i, j-3, b) &
101100
+ coeffs(3)*u(i, j-2, b) + coeffs(4)*u(i, j-1, b) &
102101
+ coeffs(5)*u(i, j, b) &
103102
+ coeffs(6)*u_e(i, 1, b) + coeffs(7)*u_e(i, 2, b) &
104103
+ coeffs(8)*u_e(i, 3, b) + coeffs(9)*u_e(i, 4, b)
105-
du(i, j, b) = ffr(j)*(du(i, j, b) - alfa*du(i, j-1, b))
104+
du(i, j, b) = ffr(j)*(du(i, j, b) - faf(j)*du(i, j-1, b))
106105

107106
send_u_e(i, 1, b) = du(i, n, b)
108107

109108
! Backward pass of the hybrid algorithm
110109
do j = n - 2, 2, -1
111110
du(i, j, b) = du(i, j, b) - fbc(j)*du(i, j + 1, b)
112111
end do
113-
du(i, 1, b) = ffr(1)*(du(i, 1, b) - fbc(1)*du(i, 2, b))
112+
du(i, 1, b) = last_r*(du(i, 1, b) - fbc(1)*du(i, 2, b))
114113
send_u_b(i, 1, b) = du(i, 1, b)
115114

116115
end subroutine der_univ_dist
117116

118117
attributes(global) subroutine der_univ_subs(du, recv_u_b, recv_u_e, &
119-
n, alfa, dist_sa, dist_sc)
118+
n, dist_sa, dist_sc)
120119
implicit none
121120

122121
! Arguments
123122
real(dp), device, intent(out), dimension(:, :, :) :: du
124123
real(dp), device, intent(in), dimension(:, :, :) :: recv_u_b, recv_u_e
125124
real(dp), device, intent(in), dimension(:) :: dist_sa, dist_sc
126-
real(dp), value, intent(in) :: alfa
127125
integer, value, intent(in) :: n
128126

129127
! Local variables

src/derparams.f90

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,20 @@ subroutine der_1_vv()
1010
end subroutine der_1_vv
1111

1212
subroutine der_2_vv(coeffs, coeffs_b, coeffs_e, &
13-
dist_fr, dist_bc, dist_sa, dist_sc, &
14-
n_halo, alfa, dx2, n, bcond)
13+
dist_fr, dist_bc, dist_af, dist_sa, dist_sc, &
14+
n_halo, dx2, n, bcond)
1515
implicit none
1616

1717
real(dp), allocatable, dimension(:), intent(out) :: coeffs, &
18-
dist_fr, dist_bc, dist_sa, dist_sc
18+
dist_fr, dist_bc, dist_af, dist_sa, dist_sc
1919
real(dp), allocatable, dimension(:,:), intent(out) :: coeffs_b, coeffs_e
2020
integer, intent(out) :: n_halo
21-
real(dp), intent(out) :: alfa
2221
real(dp), intent(in) :: dx2
2322
integer, intent(in) :: n
2423
character(len=*), intent(in) :: bcond
2524

2625
real(dp), allocatable :: dist_b(:)
27-
real(dp) :: asi, bsi, csi, dsi
26+
real(dp) :: alfa, asi, bsi, csi, dsi
2827
integer :: i, n_stencil
2928

3029
allocate(dist_sa(n), dist_sc(n), dist_b(n))
@@ -55,14 +54,16 @@ subroutine der_2_vv(coeffs, coeffs_b, coeffs_e, &
5554
print*, 'Boundary condition is not recognized :', bcond
5655
end select
5756

58-
call process_dist(dist_fr, dist_bc, dist_sa, dist_sc, dist_b, n)
57+
call process_dist(dist_fr, dist_bc, dist_af, dist_sa, dist_sc, dist_b, n)
5958

6059
end subroutine der_2_vv
6160

62-
subroutine process_dist(dist_fr, dist_bc, dist_sa, dist_sc, dist_b, n)
61+
subroutine process_dist(dist_fr, dist_bc, dist_af, &
62+
dist_sa, dist_sc, dist_b, n)
6363
implicit none
6464

65-
real(dp), allocatable, dimension(:), intent(out) :: dist_fr, dist_bc
65+
real(dp), allocatable, dimension(:), intent(out) :: dist_fr, dist_bc, &
66+
dist_af
6667
real(dp), dimension(:), intent(inout) :: dist_sa, dist_sc, dist_b
6768
integer, intent(in) :: n
6869

@@ -71,7 +72,8 @@ subroutine process_dist(dist_fr, dist_bc, dist_sa, dist_sc, dist_b, n)
7172
m = n
7273
nrank = 0; nproc = 1
7374

74-
allocate(dist_fr(n), dist_bc(n))
75+
! forward factors, backward factors, and auxiliary factor
76+
allocate(dist_fr(n), dist_bc(n), dist_af(n))
7577

7678
do nrank = 0, nproc-1
7779

@@ -80,9 +82,11 @@ subroutine process_dist(dist_fr, dist_bc, dist_sa, dist_sc, dist_b, n)
8082
dist_sa(i) = dist_sa(i)/dist_b(i)
8183
dist_sc(i) = dist_sc(i)/dist_b(i)
8284
dist_bc(i) = dist_sc(i)
85+
dist_af(i) = 1._dp/dist_b(i)
8386
end do
8487
do i = 3+m*nrank, m+m*nrank
8588
dist_fr(i) = 1.d0/(dist_b(i)-dist_sa(i)*dist_sc(i-1))
89+
dist_af(i) = dist_sa(i)
8690
dist_sa(i) = -dist_fr(i)*dist_sa(i)*dist_sa(i-1)
8791
dist_sc(i) = dist_fr(i)*dist_sc(i)
8892
!dist_bc(i) = dist_sc(i)
@@ -92,7 +96,7 @@ subroutine process_dist(dist_fr, dist_bc, dist_sa, dist_sc, dist_b, n)
9296
dist_bc(i) = dist_sc(i)
9397
dist_sc(i) = -dist_sc(i)*dist_sc(i+1)
9498
end do
95-
! this is not good
99+
! dist_fr(1) is not filled by the forward-factor loop (which starts at i = 3), so this slot stores an extra factor instead.
96100
dist_fr(1+m*nrank) = 1.d0/(1.d0-dist_sc(1+m*nrank)*dist_sa(2+m*nrank))
97101

98102
dist_sa(1+m*nrank) = dist_fr(1+m*nrank)*dist_sa(1+m*nrank)

0 commit comments

Comments
 (0)