Skip to content

Commit 94bd0df

Browse files
committed
fix type stability of ifelse
1 parent c467855 commit 94bd0df

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

src/smoother.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@ end
2929
function gs!(A, b, x, start, step, stop)
3030
n = size(A, 1)
3131
z = zero(eltype(A))
32-
@inbounds for col = 1:size(x, 2)
33-
for i = start:step:stop
32+
@inbounds for col in 1:size(x, 2)
33+
for i in start:step:stop
3434
rsum = z
3535
d = z
3636
for j in nzrange(A, i)
3737
row = A.rowval[j]
3838
val = A.nzval[j]
3939
d = ifelse(i == row, val, d)
40-
rsum += ifelse(i == row, 0, val * x[row, col])
40+
rsum += ifelse(i == row, z, val * x[row, col])
4141
end
4242
x[i, col] = ifelse(d == 0, x[i, col], (b[i, col] - rsum) / d)
4343
end
@@ -57,6 +57,7 @@ function (jacobi::Jacobi)(A, x, b)
5757
ω = jacobi.ω
5858
one = Base.one(eltype(A))
5959
temp = jacobi.temp
60+
z = zero(eltype(A))
6061

6162
for i in 1:jacobi.iter
6263
@inbounds for col = 1:size(x, 2)
@@ -65,15 +66,15 @@ function (jacobi::Jacobi)(A, x, b)
6566
end
6667

6768
for i = 1:size(A, 1)
68-
rsum = zero(eltype(A))
69-
diag = zero(eltype(A))
69+
rsum = z
70+
diag = z
7071

7172
for j in nzrange(A, i)
7273
row = A.rowval[j]
7374
val = A.nzval[j]
7475

7576
diag = ifelse(row == i, val, diag)
76-
rsum += ifelse(row == i, 0, val * temp[row, col])
77+
rsum += ifelse(row == i, z, val * temp[row, col])
7778
end
7879

7980
xcand = (one - ω) * temp[i, col] + ω * ((b[i, col] - rsum) / diag)
@@ -116,15 +117,16 @@ function (pjacobmapper::ParallelJacobiMapper)(i)
116117
col = pjacobmapper.col
117118
118119
one = Base.one(eltype(A))
119-
rsum = zero(eltype(A))
120-
diag = zero(eltype(A))
120+
z = zero(eltype(A))
121+
rsum = z
122+
diag = z
121123
122124
for j in nzrange(A, i)
123125
row = A.rowval[j]
124126
val = A.nzval[j]
125127
126128
diag = ifelse(row == i, val, diag)
127-
rsum += ifelse(row == i, 0, val * temp[row, col])
129+
rsum += ifelse(row == i, z, val * temp[row, col])
128130
end
129131
xcand = (one - ω) * temp[i, col] + ω * ((b[i, col] - rsum) / diag)
130132

0 commit comments

Comments
 (0)