@@ -13,38 +13,41 @@ struct MultiLevel{S, Pre, Post, TA, TP, TR, TW}
13
13
workspace:: TW
14
14
end
15
15
16
- struct MultiLevelWorkspace{T , bs}
17
- coarse_xs:: Vector{Vector{Vector{T}} }
18
- coarse_bs:: Vector{Vector{Vector{T}} }
19
- res_vecs:: Vector{Vector{Vector{T}} }
16
+ struct MultiLevelWorkspace{TX , bs}
17
+ coarse_xs:: Vector{TX }
18
+ coarse_bs:: Vector{TX }
19
+ res_vecs:: Vector{TX }
20
20
end
21
21
function MultiLevelWorkspace (:: Type{Val{bs}} , :: Type{T} ) where {bs, T<: Number }
22
- MultiLevelWorkspace {T, bs} ( Vector{Vector{Vector{T}}}[],
23
- Vector{Vector{Vector{T}}}[],
24
- Vector{Vector{Vector{T}}}[])
22
+ if bs === 1
23
+ TX = Vector{T}
24
+ else
25
+ TX = Matrix{T}
26
+ end
27
+ MultiLevelWorkspace {TX, bs} (TX[], TX[], TX[])
25
28
end
26
- Base. eltype (w:: MultiLevelWorkspace{T } ) where T = T
27
- blocksize (w:: MultiLevelWorkspace{T , bs} ) where {T , bs} = bs
29
+ Base. eltype (w:: MultiLevelWorkspace{TX } ) where TX = eltype (TX)
30
+ blocksize (w:: MultiLevelWorkspace{TX , bs} ) where {TX , bs} = bs
28
31
29
- function residual! (m:: MultiLevelWorkspace{T , bs} , n) where {T , bs}
32
+ function residual! (m:: MultiLevelWorkspace{TX , bs} , n) where {TX , bs}
30
33
if bs === 1
31
- push! (m. res_vecs, [ Vector {T} (undef, n) for _ in 1 : nthreads ()] )
34
+ push! (m. res_vecs, TX (undef, n))
32
35
else
33
- push! (m. res_vecs, [ Vector {T} (undef, n, bs) for _ in 1 : nthreads ()] )
36
+ push! (m. res_vecs, TX (undef, n, bs))
34
37
end
35
38
end
36
- function coarse_x! (m:: MultiLevelWorkspace{T , bs} , n) where {T , bs}
39
+ function coarse_x! (m:: MultiLevelWorkspace{TX , bs} , n) where {TX , bs}
37
40
if bs === 1
38
- push! (m. coarse_xs, [ Vector {T} (undef, n) for _ in 1 : nthreads ()] )
41
+ push! (m. coarse_xs, TX (undef, n))
39
42
else
40
- push! (m. coarse_xs, [ Vector {T} (undef, n, bs) for _ in 1 : nthreads ()] )
43
+ push! (m. coarse_xs, TX (undef, n, bs))
41
44
end
42
45
end
43
- function coarse_b! (m:: MultiLevelWorkspace{T , bs} , n) where {T , bs}
46
+ function coarse_b! (m:: MultiLevelWorkspace{TX , bs} , n) where {TX , bs}
44
47
if bs === 1
45
- push! (m. coarse_bs, [ Vector {T} (undef, n) for _ in 1 : nthreads ()] )
48
+ push! (m. coarse_bs, TX (undef, n))
46
49
else
47
- push! (m. coarse_bs, [ Vector {T} (undef, n, bs) for _ in 1 : nthreads ()] )
50
+ push! (m. coarse_bs, TX (undef, n, bs))
48
51
end
49
52
end
50
53
@@ -147,7 +150,7 @@ function solve!(x, ml::MultiLevel, b::AbstractArray{T},
147
150
tol:: Float64 = 1e-5 ,
148
151
verbose:: Bool = false ,
149
152
log:: Bool = false ,
150
- calculate_residual = false ) where {T}
153
+ calculate_residual = true ) where {T}
151
154
152
155
A = length (ml) == 1 ? ml. final_A : ml. levels[1 ]. A
153
156
V = promote_type (eltype (A), eltype (b))
@@ -184,14 +187,14 @@ function __solve!(x, ml, v::V, b, lvl)
184
187
A = ml. levels[lvl]. A
185
188
ml. presmoother (A, x, b)
186
189
187
- res = ml. workspace. res_vecs[lvl][ threadid ()]
190
+ res = ml. workspace. res_vecs[lvl]
188
191
mul! (res, A, x)
189
192
reshape (res, size (b)) .= b .- reshape (res, size (b))
190
193
191
- coarse_b = ml. workspace. coarse_bs[lvl][ threadid ()]
194
+ coarse_b = ml. workspace. coarse_bs[lvl]
192
195
mul! (coarse_b, ml. levels[lvl]. R, res)
193
196
194
- coarse_x = ml. workspace. coarse_xs[lvl][ threadid ()]
197
+ coarse_x = ml. workspace. coarse_xs[lvl]
195
198
coarse_x .= 0
196
199
if lvl == length (ml. levels)
197
200
ml. coarse_solver (coarse_x, coarse_b)
0 commit comments