Skip to content

Commit 5137b7b

Browse files
authored
Merge pull request #68 from JuliaLinearAlgebra/RA/mt
Make AMG thread safe
2 parents cfee0f0 + 6a269fd commit 5137b7b

File tree

1 file changed

+22
-25
lines changed

1 file changed

+22
-25
lines changed

src/multilevel.jl

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,41 +13,38 @@ struct MultiLevel{S, Pre, Post, TA, TP, TR, TW}
1313
workspace::TW
1414
end
1515

16-
struct MultiLevelWorkspace{TX, bs}
17-
coarse_xs::Vector{TX}
18-
coarse_bs::Vector{TX}
19-
res_vecs::Vector{TX}
16+
struct MultiLevelWorkspace{T, bs}
17+
coarse_xs::Vector{Vector{Vector{T}}}
18+
coarse_bs::Vector{Vector{Vector{T}}}
19+
res_vecs::Vector{Vector{Vector{T}}}
2020
end
2121
function MultiLevelWorkspace(::Type{Val{bs}}, ::Type{T}) where {bs, T<:Number}
22-
if bs === 1
23-
TX = Vector{T}
24-
else
25-
TX = Matrix{T}
26-
end
27-
MultiLevelWorkspace{TX, bs}(TX[], TX[], TX[])
22+
MultiLevelWorkspace{T, bs}( Vector{Vector{Vector{T}}}[],
23+
Vector{Vector{Vector{T}}}[],
24+
Vector{Vector{Vector{T}}}[])
2825
end
29-
Base.eltype(w::MultiLevelWorkspace{TX}) where TX = eltype(TX)
30-
blocksize(w::MultiLevelWorkspace{TX, bs}) where {TX, bs} = bs
26+
Base.eltype(w::MultiLevelWorkspace{T}) where T = T
27+
blocksize(w::MultiLevelWorkspace{T, bs}) where {T, bs} = bs
3128

32-
function residual!(m::MultiLevelWorkspace{TX, bs}, n) where {TX, bs}
29+
function residual!(m::MultiLevelWorkspace{T, bs}, n) where {T, bs}
3330
if bs === 1
34-
push!(m.res_vecs, TX(undef, n))
31+
push!(m.res_vecs, [Vector{T}(undef, n) for _ in 1:nthreads()])
3532
else
36-
push!(m.res_vecs, TX(undef, n, bs))
33+
push!(m.res_vecs, [Vector{T}(undef, n, bs) for _ in 1:nthreads()])
3734
end
3835
end
39-
function coarse_x!(m::MultiLevelWorkspace{TX, bs}, n) where {TX, bs}
36+
function coarse_x!(m::MultiLevelWorkspace{T, bs}, n) where {T, bs}
4037
if bs === 1
41-
push!(m.coarse_xs, TX(undef, n))
38+
push!(m.coarse_xs, [Vector{T}(undef, n) for _ in 1:nthreads()])
4239
else
43-
push!(m.coarse_xs, TX(undef, n, bs))
40+
push!(m.coarse_xs, [Vector{T}(undef, n, bs) for _ in 1:nthreads()])
4441
end
4542
end
46-
function coarse_b!(m::MultiLevelWorkspace{TX, bs}, n) where {TX, bs}
43+
function coarse_b!(m::MultiLevelWorkspace{T, bs}, n) where {T, bs}
4744
if bs === 1
48-
push!(m.coarse_bs, TX(undef, n))
45+
push!(m.coarse_bs, [Vector{T}(undef, n) for _ in 1:nthreads()])
4946
else
50-
push!(m.coarse_bs, TX(undef, n, bs))
47+
push!(m.coarse_bs, [Vector{T}(undef, n, bs) for _ in 1:nthreads()])
5148
end
5249
end
5350

@@ -150,7 +147,7 @@ function solve!(x, ml::MultiLevel, b::AbstractArray{T},
150147
tol::Float64 = 1e-5,
151148
verbose::Bool = false,
152149
log::Bool = false,
153-
calculate_residual = true) where {T}
150+
calculate_residual = false) where {T}
154151

155152
A = length(ml) == 1 ? ml.final_A : ml.levels[1].A
156153
V = promote_type(eltype(A), eltype(b))
@@ -187,14 +184,14 @@ function __solve!(x, ml, v::V, b, lvl)
187184
A = ml.levels[lvl].A
188185
ml.presmoother(A, x, b)
189186

190-
res = ml.workspace.res_vecs[lvl]
187+
res = ml.workspace.res_vecs[lvl][threadid()]
191188
mul!(res, A, x)
192189
reshape(res, size(b)) .= b .- reshape(res, size(b))
193190

194-
coarse_b = ml.workspace.coarse_bs[lvl]
191+
coarse_b = ml.workspace.coarse_bs[lvl][threadid()]
195192
mul!(coarse_b, ml.levels[lvl].R, res)
196193

197-
coarse_x = ml.workspace.coarse_xs[lvl]
194+
coarse_x = ml.workspace.coarse_xs[lvl][threadid()]
198195
coarse_x .= 0
199196
if lvl == length(ml.levels)
200197
ml.coarse_solver(coarse_x, coarse_b)

0 commit comments

Comments
 (0)