Skip to content

Commit d89759d

Browse files
committed
Allow vm and vn to be of different types
For problems `Ax=b`, this allows the vector space of `x` (`vn`) to be different from the vector space of `b` (`vm`); also for some bipartite problems [A B; C D] [x; y] = [b; c] it allows `x` and `y` to be of different types. The key change is to introduce `KrylovWorkspaceNext{T,FC,Sm,Sn}`, for which `Sm` and `Sn` are the types of the length-`m` and length-`n` vectors, respectively. `KrylovWorkspace{T,FC,S}` is now just a type alias for `KrylovWorkspaceNext{T,FC,S,S}`. This is done for reasons of backwards compatibility; in the next breaking release, we should rename `KrylovWorkspaceNext` to `KrylovWorkspace`. This PR exploits the new flexibility for CGLS, LSQR, TRICG, TRIMR, and GPMR. Extending to other solvers is left as an exercise for the reader. Fixes #1037
1 parent 4e3d7a4 commit d89759d

File tree

15 files changed

+265
-128
lines changed

15 files changed

+265
-128
lines changed

src/cgls.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ args_cgls = (:A, :b)
126126
kwargs_cgls = (:M, :ldiv, :radius, , :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)
127127

128128
@eval begin
129-
function cgls!(workspace :: CglsWorkspace{T,FC,S}, $(def_args_cgls...); $(def_kwargs_cgls...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}
129+
function cgls!(workspace :: CglsWorkspace{T,FC,Sm,Sn}, $(def_args_cgls...); $(def_kwargs_cgls...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, Sm <: AbstractVector{FC}, Sn <: AbstractVector{FC}}
130130

131131
# Timer
132132
start_time = time_ns()
@@ -142,13 +142,13 @@ kwargs_cgls = (:M, :ldiv, :radius, :λ, :atol, :rtol, :itmax, :timemax, :verbose
142142

143143
# Check type consistency
144144
eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-vector products."
145-
ktypeof(b) == S || error("ktypeof(b) must be equal to $S")
145+
ktypeof(b) == Sm || error("ktypeof(b) must be equal to $Sm")
146146

147147
# Compute the adjoint of A
148148
Aᴴ = A'
149149

150150
# Set up workspace.
151-
allocate_if(!MisI, workspace, :Mr, S, workspace.r) # The length of Mr is m
151+
allocate_if(!MisI, workspace, :Mr, Sm, workspace.r) # The length of Mr is m
152152
x, p, s, r, q, stats = workspace.x, workspace.p, workspace.s, workspace.r, workspace.q, workspace.stats
153153
rNorms, ArNorms = stats.residuals, stats.Aresiduals
154154
reset!(stats)

src/gpmr.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ optargs_gpmr = (:x0, :y0)
157157
kwargs_gpmr = (:C, :D, :E, :F, :ldiv, :gsp, , , :reorthogonalization, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)
158158

159159
@eval begin
160-
function gpmr!(workspace :: GpmrWorkspace{T,FC,S}, $(def_args_gpmr...); $(def_kwargs_gpmr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}
160+
function gpmr!(workspace :: GpmrWorkspace{T,FC,Sm,Sn}, $(def_args_gpmr...); $(def_kwargs_gpmr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, Sm <: AbstractVector{FC}, Sn <: AbstractVector{FC}}
161161

162162
# Timer
163163
start_time = time_ns()
@@ -181,8 +181,8 @@ kwargs_gpmr = (:C, :D, :E, :F, :ldiv, :gsp, :λ, :μ, :reorthogonalization, :ato
181181
# Check type consistency
182182
eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-vector products."
183183
eltype(B) == FC || @warn "eltype(B) ≠ $FC. This could lead to errors or additional allocations in operator-vector products."
184-
ktypeof(b) == S || error("ktypeof(b) must be equal to $S")
185-
ktypeof(c) == S || error("ktypeof(c) must be equal to $S")
184+
ktypeof(b) == Sm || error("ktypeof(b) must be equal to $Sm")
185+
ktypeof(c) == Sn || error("ktypeof(c) must be equal to $Sn")
186186

187187
# Determine λ and μ associated to generalized saddle point systems.
188188
gsp &&= one(FC) ; μ = zero(FC))
@@ -192,10 +192,10 @@ kwargs_gpmr = (:C, :D, :E, :F, :ldiv, :gsp, :λ, :μ, :reorthogonalization, :ato
192192
warm_start && 0) && !FisI && error("Warm-start with right preconditioners is not supported.")
193193

194194
# Set up workspace.
195-
allocate_if(!CisI, workspace, :q , S, workspace.x) # The length of q is m
196-
allocate_if(!DisI, workspace, :p , S, workspace.y) # The length of p is n
197-
allocate_if(!EisI, workspace, :wB, S, workspace.x) # The length of wB is m
198-
allocate_if(!FisI, workspace, :wA, S, workspace.y) # The length of wA is n
195+
allocate_if(!CisI, workspace, :q , Sm, workspace.x) # The length of q is m
196+
allocate_if(!DisI, workspace, :p , Sn, workspace.y) # The length of p is n
197+
allocate_if(!EisI, workspace, :wB, Sm, workspace.x) # The length of wB is m
198+
allocate_if(!FisI, workspace, :wA, Sn, workspace.y) # The length of wA is n
199199
wA, wB, dA, dB, Δx, Δy = workspace.wA, workspace.wB, workspace.dA, workspace.dB, workspace.Δx, workspace.Δy
200200
x, y, V, U, gs, gc = workspace.x, workspace.y, workspace.V, workspace.U, workspace.gs, workspace.gc
201201
zt, R, stats = workspace.zt, workspace.R, workspace.stats

src/interface.jl

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ for (workspace, krylov, args, def_args, optargs, def_optargs, kwargs, def_kwargs
138138

139139
krylov_solve(::Val{Symbol($krylov)}, $(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} = $(krylov)($(args...); $(kwargs...))
140140
end
141-
elseif krylov in (:diom, :dqgmres, :fom, :gmres, :fgmres, :gpmr)
141+
elseif krylov in (:diom, :dqgmres, :fom, :gmres, :fgmres)
142142
@eval begin
143143
function $(krylov)($(def_args...); memory::Int = 20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
144144
start_time = time_ns()
@@ -164,6 +164,35 @@ for (workspace, krylov, args, def_args, optargs, def_optargs, kwargs, def_kwargs
164164
return results(workspace)
165165
end
166166

167+
krylov_solve(::Val{Symbol($krylov)}, $(def_args...), $(def_optargs...); memory::Int = 20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} = $(krylov)($(args...), $(optargs...); memory, $(kwargs...))
168+
end
169+
end
170+
elseif krylov == :gpmr
171+
@eval begin
172+
function $(krylov)($(def_args...); memory::Int = 20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
173+
start_time = time_ns()
174+
workspace = $workspace(KrylovConstructor(b, c, similar(b, 0), similar(c, 0)); memory)
175+
elapsed_time = start_time |> ktimer
176+
timemax -= elapsed_time
177+
$(krylov!)(workspace, $(args...); $(kwargs...))
178+
workspace.stats.timer += elapsed_time
179+
return results(workspace)
180+
end
181+
182+
krylov_solve(::Val{Symbol($krylov)}, $(def_args...); memory::Int = 20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} = $(krylov)($(args...); memory, $(kwargs...))
183+
184+
if !isempty($optargs)
185+
function $(krylov)($(def_args...), $(def_optargs...); memory::Int = 20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
186+
start_time = time_ns()
187+
workspace = $workspace(KrylovConstructor(b, c, similar(b, 0), similar(c, 0)); memory)
188+
warm_start!(workspace, $(optargs...))
189+
elapsed_time = start_time |> ktimer
190+
timemax -= elapsed_time
191+
$(krylov!)(workspace, $(args...); $(kwargs...))
192+
workspace.stats.timer += elapsed_time
193+
return results(workspace)
194+
end
195+
167196
krylov_solve(::Val{Symbol($krylov)}, $(def_args...), $(def_optargs...); memory::Int = 20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} = $(krylov)($(args...), $(optargs...); memory, $(kwargs...))
168197
end
169198
end
@@ -196,6 +225,35 @@ for (workspace, krylov, args, def_args, optargs, def_optargs, kwargs, def_kwargs
196225
krylov_solve(::Val{Symbol($krylov)}, $(def_args...), $(def_optargs...); window::Int = 5, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} = $(krylov)($(args...), $(optargs...); window, $(kwargs...))
197226
end
198227
end
228+
elseif krylov in (:tricg, :trimr)
229+
@eval begin
230+
function $(krylov)($(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
231+
start_time = time_ns()
232+
workspace = $workspace(KrylovConstructor(b, c, similar(b, 0), similar(c, 0)))
233+
elapsed_time = start_time |> ktimer
234+
timemax -= elapsed_time
235+
$(krylov!)(workspace, $(args...); $(kwargs...))
236+
workspace.stats.timer += elapsed_time
237+
return results(workspace)
238+
end
239+
240+
krylov_solve(::Val{Symbol($krylov)}, $(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} = $(krylov)($(args...); $(kwargs...))
241+
242+
if !isempty($optargs)
243+
function $(krylov)($(def_args...), $(def_optargs...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
244+
start_time = time_ns()
245+
workspace = $workspace(KrylovConstructor(b, c, similar(b, 0), similar(c, 0)))
246+
warm_start!(workspace, $(optargs...))
247+
elapsed_time = start_time |> ktimer
248+
timemax -= elapsed_time
249+
$(krylov!)(workspace, $(args...); $(kwargs...))
250+
workspace.stats.timer += elapsed_time
251+
return results(workspace)
252+
end
253+
254+
krylov_solve(::Val{Symbol($krylov)}, $(def_args...), $(def_optargs...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} = $(krylov)($(args...), $(optargs...); $(kwargs...))
255+
end
256+
end
199257
else
200258
@eval begin
201259
function $(krylov)($(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}

src/krylov_show.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,13 @@ end
6363
6464
Statistics of `workspace` are displayed if `show_stats` is set to true.
6565
"""
66-
function show(io :: IO, workspace :: Union{KrylovWorkspace{T,FC,S}, BlockKrylovWorkspace{T,FC,S}}; show_stats :: Bool=true) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}
66+
function show(io :: IO, workspace :: Union{KrylovWorkspaceNext{T,FC,Sm,S}, BlockKrylovWorkspace{T,FC,S}}; show_stats :: Bool=true) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}, Sm <: AbstractVector{FC}}
6767
type_workspace = typeof(workspace)
6868
name_workspace = string(type_workspace.name.name)
6969
name_stats = string(typeof(workspace.stats).name.name)
7070
nbytes = sizeof(workspace)
7171
storage = format_bytes(nbytes)
72-
architecture = S <: Vector ? "CPU" : "GPU"
72+
architecture = S <: Vector ? "CPU" : "GPU" # FIXME cannot assume that all non-Vector types are GPU types
7373
l1 = max(length(name_workspace), length(string(FC)) + 11) # length("Precision: ") = 11
7474
nchar = type_workspace <: Union{CgLanczosShiftWorkspace, FomWorkspace, DiomWorkspace, DqgmresWorkspace, GmresWorkspace, FgmresWorkspace, GpmrWorkspace, BlockGmresWorkspace} ? 8 : 0 # length("Vector{}") = 8
7575
l2 = max(ndigits(workspace.m) + 7, length(architecture) + 14, length(string(S)) + nchar) # length("nrows: ") = 7 and length("Architecture: ") = 14

0 commit comments

Comments
 (0)