From d4d1780c592d6a894cfd66359f2213d2433b1cb4 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Sat, 22 Nov 2025 04:30:35 -0600 Subject: [PATCH] Allow vm and vn to be of different types For problems `Ax=b`, this allows the vector space of `x` (`vn`) to be different from the vector space of `b` (`vm`); also for some bipartite problems [A B; C D] [x; y] = [b; c] it allows `x` and `y` to be of different types. The key change is to introduce `KrylovWorkspaceNext{T,FC,Sm,Sn}`, for which `Sm` and `Sn` are the types of the length-`m` and length-`n` vectors, respectively. `KrylovWorkspace{T,FC,S}` is now just a type alias for `KrylovWorkspaceNext{T,FC,S,S}`. This is done for reasons of backwards compatibility; in the next breaking release, we should rename `KrylovWorkspaceNext` to `KrylovWorkspace`. This PR exploits the new flexibility for CGLS, LSQR, TRICG, TRIMR, and GPMR. Extending to other solvers is left as an exercise for the reader. Fixes #1037 --- src/cgls.jl | 6 +- src/gpmr.jl | 14 +-- src/interface.jl | 60 ++++++++++- src/krylov_show.jl | 4 +- src/krylov_workspaces.jl | 208 ++++++++++++++++++++------------------- src/lsqr.jl | 8 +- src/tricg.jl | 10 +- src/trimr.jl | 10 +- test/test_cgls.jl | 9 ++ test/test_gpmr.jl | 14 ++- test/test_interface.jl | 8 ++ test/test_lsqr.jl | 9 ++ test/test_tricg.jl | 12 +++ test/test_trimr.jl | 12 +++ test/test_utils.jl | 9 ++ 15 files changed, 265 insertions(+), 128 deletions(-) diff --git a/src/cgls.jl b/src/cgls.jl index 5963d6dc9..e6021d071 100644 --- a/src/cgls.jl +++ b/src/cgls.jl @@ -126,7 +126,7 @@ args_cgls = (:A, :b) kwargs_cgls = (:M, :ldiv, :radius, :λ, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream) @eval begin - function cgls!(workspace :: CglsWorkspace{T,FC,S}, $(def_args_cgls...); $(def_kwargs_cgls...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}} + function cgls!(workspace :: CglsWorkspace{T,FC,Sm,Sn}, $(def_args_cgls...); $(def_kwargs_cgls...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, Sm <: AbstractVector{FC}, Sn <: AbstractVector{FC}} # Timer start_time = time_ns() @@ -142,13 +142,13 @@ kwargs_cgls = (:M, :ldiv, :radius, :λ, :atol, :rtol, :itmax, :timemax, :verbose # Check type consistency eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-vector products." - ktypeof(b) == S || error("ktypeof(b) must be equal to $S") + ktypeof(b) == Sm || error("ktypeof(b) must be equal to $Sm") # Compute the adjoint of A Aᴴ = A' # Set up workspace. - allocate_if(!MisI, workspace, :Mr, S, workspace.r) # The length of Mr is m + allocate_if(!MisI, workspace, :Mr, Sm, workspace.r) # The length of Mr is m x, p, s, r, q, stats = workspace.x, workspace.p, workspace.s, workspace.r, workspace.q, workspace.stats rNorms, ArNorms = stats.residuals, stats.Aresiduals reset!(stats) diff --git a/src/gpmr.jl b/src/gpmr.jl index 275f741a3..f95bc2517 100644 --- a/src/gpmr.jl +++ b/src/gpmr.jl @@ -157,7 +157,7 @@ optargs_gpmr = (:x0, :y0) kwargs_gpmr = (:C, :D, :E, :F, :ldiv, :gsp, :λ, :μ, :reorthogonalization, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream) @eval begin - function gpmr!(workspace :: GpmrWorkspace{T,FC,S}, $(def_args_gpmr...); $(def_kwargs_gpmr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}} + function gpmr!(workspace :: GpmrWorkspace{T,FC,Sm,Sn}, $(def_args_gpmr...); $(def_kwargs_gpmr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, Sm <: AbstractVector{FC}, Sn <: AbstractVector{FC}} # Timer start_time = time_ns() @@ -181,8 +181,8 @@ kwargs_gpmr = (:C, :D, :E, :F, :ldiv, :gsp, :λ, :μ, :reorthogonalization, :ato # Check type consistency eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-vector products." eltype(B) == FC || @warn "eltype(B) ≠ $FC. This could lead to errors or additional allocations in operator-vector products." - ktypeof(b) == S || error("ktypeof(b) must be equal to $S") - ktypeof(c) == S || error("ktypeof(c) must be equal to $S") + ktypeof(b) == Sm || error("ktypeof(b) must be equal to $Sm") + ktypeof(c) == Sn || error("ktypeof(c) must be equal to $Sn") # Determine λ and μ associated to generalized saddle point systems. gsp && (λ = one(FC) ; μ = zero(FC)) @@ -192,10 +192,10 @@ kwargs_gpmr = (:C, :D, :E, :F, :ldiv, :gsp, :λ, :μ, :reorthogonalization, :ato warm_start && (μ ≠ 0) && !FisI && error("Warm-start with right preconditioners is not supported.") # Set up workspace. - allocate_if(!CisI, workspace, :q , S, workspace.x) # The length of q is m - allocate_if(!DisI, workspace, :p , S, workspace.y) # The length of p is n - allocate_if(!EisI, workspace, :wB, S, workspace.x) # The length of wB is m - allocate_if(!FisI, workspace, :wA, S, workspace.y) # The length of wA is n + allocate_if(!CisI, workspace, :q , Sm, workspace.x) # The length of q is m + allocate_if(!DisI, workspace, :p , Sn, workspace.y) # The length of p is n + allocate_if(!EisI, workspace, :wB, Sm, workspace.x) # The length of wB is m + allocate_if(!FisI, workspace, :wA, Sn, workspace.y) # The length of wA is n wA, wB, dA, dB, Δx, Δy = workspace.wA, workspace.wB, workspace.dA, workspace.dB, workspace.Δx, workspace.Δy x, y, V, U, gs, gc = workspace.x, workspace.y, workspace.V, workspace.U, workspace.gs, workspace.gc zt, R, stats = workspace.zt, workspace.R, workspace.stats diff --git a/src/interface.jl b/src/interface.jl index f34a4cfc6..9330f4414 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -138,7 +138,7 @@ for (workspace, krylov, args, def_args, optargs, def_optargs, kwargs, def_kwargs krylov_solve(::Val{Symbol($krylov)}, $(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} = $(krylov)($(args...); $(kwargs...)) end - elseif krylov in (:diom, :dqgmres, :fom, :gmres, :fgmres, :gpmr) + elseif krylov in (:diom, :dqgmres, :fom, :gmres, :fgmres) @eval begin function $(krylov)($(def_args...); memory::Int = 20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} start_time = time_ns() @@ -164,6 +164,35 @@ for (workspace, krylov, args, def_args, optargs, def_optargs, kwargs, def_kwargs return results(workspace) end + krylov_solve(::Val{Symbol($krylov)}, $(def_args...), $(def_optargs...); memory::Int = 20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} = $(krylov)($(args...), $(optargs...); memory, $(kwargs...)) + end + end + elseif krylov == :gpmr + @eval begin + function $(krylov)($(def_args...); memory::Int = 20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} + start_time = time_ns() + workspace = $workspace(KrylovConstructor(b, c, similar(b, 0), similar(c, 0)); memory) + elapsed_time = start_time |> ktimer + timemax -= elapsed_time + $(krylov!)(workspace, $(args...); $(kwargs...)) + workspace.stats.timer += elapsed_time + return results(workspace) + end + + krylov_solve(::Val{Symbol($krylov)}, $(def_args...); memory::Int = 20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} = $(krylov)($(args...); memory, $(kwargs...)) + + if !isempty($optargs) + function $(krylov)($(def_args...), $(def_optargs...); memory::Int = 20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} + start_time = time_ns() + workspace = $workspace(KrylovConstructor(b, c, similar(b, 0), similar(c, 0)); memory) + warm_start!(workspace, $(optargs...)) + elapsed_time = start_time |> ktimer + timemax -= elapsed_time + $(krylov!)(workspace, $(args...); $(kwargs...)) + workspace.stats.timer += elapsed_time + return results(workspace) + end + krylov_solve(::Val{Symbol($krylov)}, $(def_args...), $(def_optargs...); memory::Int = 20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} = $(krylov)($(args...), $(optargs...); memory, $(kwargs...)) end end @@ -196,6 +225,35 @@ for (workspace, krylov, args, def_args, optargs, def_optargs, kwargs, def_kwargs krylov_solve(::Val{Symbol($krylov)}, $(def_args...), $(def_optargs...); window::Int = 5, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} = $(krylov)($(args...), $(optargs...); window, $(kwargs...)) end end + elseif krylov in (:tricg, :trimr) + @eval begin + function $(krylov)($(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} + start_time = time_ns() + workspace = $workspace(KrylovConstructor(b, c, similar(b, 0), similar(c, 0))) + elapsed_time = start_time |> ktimer + timemax -= elapsed_time + $(krylov!)(workspace, $(args...); $(kwargs...)) + workspace.stats.timer += elapsed_time + return results(workspace) + end + + krylov_solve(::Val{Symbol($krylov)}, $(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} = $(krylov)($(args...); $(kwargs...)) + + if !isempty($optargs) + function $(krylov)($(def_args...), $(def_optargs...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} + start_time = time_ns() + workspace = $workspace(KrylovConstructor(b, c, similar(b, 0), similar(c, 0))) + warm_start!(workspace, $(optargs...)) + elapsed_time = start_time |> ktimer + timemax -= elapsed_time + $(krylov!)(workspace, $(args...); $(kwargs...)) + workspace.stats.timer += elapsed_time + return results(workspace) + end + + krylov_solve(::Val{Symbol($krylov)}, $(def_args...), $(def_optargs...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} = $(krylov)($(args...), $(optargs...); $(kwargs...)) + end + end else @eval begin function $(krylov)($(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} diff --git a/src/krylov_show.jl b/src/krylov_show.jl index 3cec3c77b..b362be294 100644 --- a/src/krylov_show.jl +++ b/src/krylov_show.jl @@ -63,13 +63,13 @@ end Statistics of `workspace` are displayed if `show_stats` is set to true. """ -function show(io :: IO, workspace :: Union{KrylovWorkspace{T,FC,S}, BlockKrylovWorkspace{T,FC,S}}; show_stats :: Bool=true) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}} +function show(io :: IO, workspace :: Union{KrylovWorkspaceNext{T,FC,S}, BlockKrylovWorkspace{T,FC,S}}; show_stats :: Bool=true) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}} type_workspace = typeof(workspace) name_workspace = string(type_workspace.name.name) name_stats = string(typeof(workspace.stats).name.name) nbytes = sizeof(workspace) storage = format_bytes(nbytes) - architecture = S <: Vector ? "CPU" : "GPU" + architecture = S <: Vector ? "CPU" : "GPU" # FIXME cannot assume that all non-Vector types are GPU types l1 = max(length(name_workspace), length(string(FC)) + 11) # length("Precision: ") = 11 nchar = type_workspace <: Union{CgLanczosShiftWorkspace, FomWorkspace, DiomWorkspace, DqgmresWorkspace, GmresWorkspace, FgmresWorkspace, GpmrWorkspace, BlockGmresWorkspace} ? 8 : 0 # length("Vector{}") = 8 l2 = max(ndigits(workspace.m) + 7, length(architecture) + 14, length(string(S)) + nchar) # length("nrows: ") = 7 and length("Architecture: ") = 14 diff --git a/src/krylov_workspaces.jl b/src/krylov_workspaces.jl index 199854c2f..34b4be6c2 100644 --- a/src/krylov_workspaces.jl +++ b/src/krylov_workspaces.jl @@ -32,11 +32,20 @@ For rectangular problems (`m ≠ n`), use the second constructor with `vm` and ` Empty vectors `vm_empty` and `vn_empty` reduce storage requirements when features such as warm-start or preconditioners are unused. These empty vectors will be replaced within a [`KrylovWorkspace`](@ref) only if required, such as when preconditioners are provided. """ -struct KrylovConstructor{S} - vm::S - vn::S - vm_empty::S - vn_empty::S +struct KrylovConstructor{Sm, Sn} + vm::Sm + vn::Sn + vm_empty::Sm + vn_empty::Sn + + function KrylovConstructor{Sm, Sn}(vm, vn, vm_empty, vn_empty) where {Sm, Sn} + eltype(Sm) === eltype(Sn) || throw(ArgumentError("KrylovConstructor requires that eltype(Sm) == eltype(Sn), got $(eltype(Sm)) and $(eltype(Sn))")) + return new{Sm, Sn}(vm, vn, vm_empty, vn_empty) + end +end + +function KrylovConstructor(vm::Sm, vn::Sn, vm_empty, vn_empty) where {Sm, Sn} + return KrylovConstructor{Sm, Sn}(vm, vn, vm_empty, vn_empty) end function KrylovConstructor(vm; vm_empty=vm) @@ -47,8 +56,12 @@ function KrylovConstructor(vm, vn; vm_empty=vm, vn_empty=vn) return KrylovConstructor(vm, vn, vm_empty, vn_empty) end +# TODO: in the next breaking release, change to KrylovWorkspace{T,FC,Sm,Sn} +# and delete the alias below +abstract type KrylovWorkspaceNext{T,FC,Sm,Sn} end + "Abstract type for using Krylov solvers in-place." -abstract type KrylovWorkspace{T,FC,S} end +const KrylovWorkspace{T,FC,S} = KrylovWorkspaceNext{T,FC,S,S} """ Workspace for the in-place method [`minres!`](@ref). @@ -975,32 +988,31 @@ The following outer constructors can be used to initialize this workspace: workspace = TricgWorkspace(A, b) workspace = TricgWorkspace(kc::KrylovConstructor) """ -mutable struct TricgWorkspace{T,FC,S} <: KrylovWorkspace{T,FC,S} +mutable struct TricgWorkspace{T,FC,Sm,Sn} <: KrylovWorkspaceNext{T,FC,Sm,Sn} m :: Int n :: Int - y :: S - N⁻¹uₖ₋₁ :: S - N⁻¹uₖ :: S - p :: S - gy₂ₖ₋₁ :: S - gy₂ₖ :: S - x :: S - M⁻¹vₖ₋₁ :: S - M⁻¹vₖ :: S - q :: S - gx₂ₖ₋₁ :: S - gx₂ₖ :: S - Δx :: S - Δy :: S - uₖ :: S - vₖ :: S + y :: Sn + N⁻¹uₖ₋₁ :: Sn + N⁻¹uₖ :: Sn + p :: Sn + gy₂ₖ₋₁ :: Sn + gy₂ₖ :: Sn + x :: Sm + M⁻¹vₖ₋₁ :: Sm + M⁻¹vₖ :: Sm + q :: Sm + gx₂ₖ₋₁ :: Sm + gx₂ₖ :: Sm + Δx :: Sm + Δy :: Sn + uₖ :: Sn + vₖ :: Sm warm_start :: Bool stats :: SimpleStats{T} end -function TricgWorkspace(kc::KrylovConstructor) - S = typeof(kc.vm) - FC = eltype(S) +function TricgWorkspace(kc::KrylovConstructor{Sm,Sn}) where {Sm,Sn} + FC = eltype(Sm) T = real(FC) m = length(kc.vm) n = length(kc.vn) @@ -1021,7 +1033,7 @@ function TricgWorkspace(kc::KrylovConstructor) uₖ = similar(kc.vn_empty) vₖ = similar(kc.vm_empty) stats = SimpleStats(0, false, false, false, 0, T[], T[], T[], 0.0, "unknown") - workspace = TricgWorkspace{T,FC,S}(m, n, y, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₁, gy₂ₖ, x, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₁, gx₂ₖ, Δx, Δy, uₖ, vₖ, false, stats) + workspace = TricgWorkspace{T,FC,Sm,Sn}(m, n, y, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₁, gy₂ₖ, x, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₁, gx₂ₖ, Δx, Δy, uₖ, vₖ, false, stats) return workspace end @@ -1046,7 +1058,7 @@ function TricgWorkspace(m::Integer, n::Integer, S::Type) vₖ = S(undef, 0) S = isconcretetype(S) ? S : typeof(x) stats = SimpleStats(0, false, false, false, 0, T[], T[], T[], 0.0, "unknown") - workspace = TricgWorkspace{T,FC,S}(m, n, y, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₁, gy₂ₖ, x, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₁, gx₂ₖ, Δx, Δy, uₖ, vₖ, false, stats) + workspace = TricgWorkspace{T,FC,S,S}(m, n, y, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₁, gy₂ₖ, x, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₁, gx₂ₖ, Δx, Δy, uₖ, vₖ, false, stats) return workspace end @@ -1065,36 +1077,35 @@ The following outer constructors can be used to initialize this workspace: workspace = TrimrWorkspace(A, b) workspace = TrimrWorkspace(kc::KrylovConstructor) """ -mutable struct TrimrWorkspace{T,FC,S} <: KrylovWorkspace{T,FC,S} +mutable struct TrimrWorkspace{T,FC,Sm,Sn} <: KrylovWorkspaceNext{T,FC,Sm,Sn} m :: Int n :: Int - y :: S - N⁻¹uₖ₋₁ :: S - N⁻¹uₖ :: S - p :: S - gy₂ₖ₋₃ :: S - gy₂ₖ₋₂ :: S - gy₂ₖ₋₁ :: S - gy₂ₖ :: S - x :: S - M⁻¹vₖ₋₁ :: S - M⁻¹vₖ :: S - q :: S - gx₂ₖ₋₃ :: S - gx₂ₖ₋₂ :: S - gx₂ₖ₋₁ :: S - gx₂ₖ :: S - Δx :: S - Δy :: S - uₖ :: S - vₖ :: S + y :: Sn + N⁻¹uₖ₋₁ :: Sn + N⁻¹uₖ :: Sn + p :: Sn + gy₂ₖ₋₃ :: Sn + gy₂ₖ₋₂ :: Sn + gy₂ₖ₋₁ :: Sn + gy₂ₖ :: Sn + x :: Sm + M⁻¹vₖ₋₁ :: Sm + M⁻¹vₖ :: Sm + q :: Sm + gx₂ₖ₋₃ :: Sm + gx₂ₖ₋₂ :: Sm + gx₂ₖ₋₁ :: Sm + gx₂ₖ :: Sm + Δx :: Sm + Δy :: Sn + uₖ :: Sn + vₖ :: Sm warm_start :: Bool stats :: SimpleStats{T} end -function TrimrWorkspace(kc::KrylovConstructor) - S = typeof(kc.vm) - FC = eltype(S) +function TrimrWorkspace(kc::KrylovConstructor{Sm,Sn}) where {Sm,Sn} + FC = eltype(Sm) T = real(FC) m = length(kc.vm) n = length(kc.vn) @@ -1119,7 +1130,7 @@ function TrimrWorkspace(kc::KrylovConstructor) uₖ = similar(kc.vn_empty) vₖ = similar(kc.vm_empty) stats = SimpleStats(0, false, false, false, 0, T[], T[], T[], 0.0, "unknown") - workspace = TrimrWorkspace{T,FC,S}(m, n, y, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₃, gy₂ₖ₋₂, gy₂ₖ₋₁, gy₂ₖ, x, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₃, gx₂ₖ₋₂, gx₂ₖ₋₁, gx₂ₖ, Δx, Δy, uₖ, vₖ, false, stats) + workspace = TrimrWorkspace{T,FC,Sm,Sn}(m, n, y, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₃, gy₂ₖ₋₂, gy₂ₖ₋₁, gy₂ₖ, x, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₃, gx₂ₖ₋₂, gx₂ₖ₋₁, gx₂ₖ, Δx, Δy, uₖ, vₖ, false, stats) return workspace end @@ -1148,7 +1159,7 @@ function TrimrWorkspace(m::Integer, n::Integer, S::Type) vₖ = S(undef, 0) S = isconcretetype(S) ? S : typeof(x) stats = SimpleStats(0, false, false, false, 0, T[], T[], T[], 0.0, "unknown") - workspace = TrimrWorkspace{T,FC,S}(m, n, y, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₃, gy₂ₖ₋₂, gy₂ₖ₋₁, gy₂ₖ, x, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₃, gx₂ₖ₋₂, gx₂ₖ₋₁, gx₂ₖ, Δx, Δy, uₖ, vₖ, false, stats) + workspace = TrimrWorkspace{T,FC,S,S}(m, n, y, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₃, gy₂ₖ₋₂, gy₂ₖ₋₁, gy₂ₖ, x, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₃, gx₂ₖ₋₂, gx₂ₖ₋₁, gx₂ₖ, Δx, Δy, uₖ, vₖ, false, stats) return workspace end @@ -1620,21 +1631,20 @@ The following outer constructors can be used to initialize this workspace: workspace = CglsWorkspace(A, b) workspace = CglsWorkspace(kc::KrylovConstructor) """ -mutable struct CglsWorkspace{T,FC,S} <: KrylovWorkspace{T,FC,S} +mutable struct CglsWorkspace{T,FC,Sm,Sn} <: KrylovWorkspaceNext{T,FC,Sm,Sn} m :: Int n :: Int - x :: S - p :: S - s :: S - r :: S - q :: S - Mr :: S + x :: Sn + p :: Sn + s :: Sn + r :: Sm + q :: Sm + Mr :: Sm stats :: SimpleStats{T} end -function CglsWorkspace(kc::KrylovConstructor) - S = typeof(kc.vm) - FC = eltype(S) +function CglsWorkspace(kc::KrylovConstructor{Sm,Sn}) where {Sm,Sn} + FC = eltype(Sm) # Sn has the same eltype T = real(FC) m = length(kc.vm) n = length(kc.vn) @@ -1645,7 +1655,7 @@ function CglsWorkspace(kc::KrylovConstructor) q = similar(kc.vm) Mr = similar(kc.vm_empty) stats = SimpleStats(0, false, false, false, 0, T[], T[], T[], 0.0, "unknown") - workspace = CglsWorkspace{T,FC,S}(m, n, x, p, s, r, q, Mr, stats) + workspace = CglsWorkspace{T,FC,Sm,Sn}(m, n, x, p, s, r, q, Mr, stats) return workspace end @@ -1660,7 +1670,7 @@ function CglsWorkspace(m::Integer, n::Integer, S::Type) Mr = S(undef, 0) S = isconcretetype(S) ? S : typeof(x) stats = SimpleStats(0, false, false, false, 0, T[], T[], T[], 0.0, "unknown") - workspace = CglsWorkspace{T,FC,S}(m, n, x, p, s, r, q, Mr, stats) + workspace = CglsWorkspace{T,FC,S,S}(m, n, x, p, s, r, q, Mr, stats) return workspace end @@ -2022,24 +2032,23 @@ The following outer constructors can be used to initialize this workspace: workspace = LsqrWorkspace(A, b) workspace = LsqrWorkspace(kc::KrylovConstructor) """ -mutable struct LsqrWorkspace{T,FC,S} <: KrylovWorkspace{T,FC,S} +mutable struct LsqrWorkspace{T,FC,Sm,Sn} <: KrylovWorkspaceNext{T,FC,Sm,Sn} m :: Int n :: Int - x :: S - Nv :: S - Aᴴu :: S - w :: S - Mu :: S - Av :: S - u :: S - v :: S + x :: Sn + Nv :: Sn + Aᴴu :: Sn + w :: Sn + Mu :: Sm + Av :: Sm + u :: Sm + v :: Sn err_vec :: Vector{T} stats :: SimpleStats{T} end -function LsqrWorkspace(kc::KrylovConstructor; window::Integer = 5) - S = typeof(kc.vm) - FC = eltype(S) +function LsqrWorkspace(kc::KrylovConstructor{Sm,Sn}; window::Integer = 5) where {Sm,Sn} + FC = eltype(Sm) T = real(FC) m = length(kc.vm) n = length(kc.vn) @@ -2053,7 +2062,7 @@ function LsqrWorkspace(kc::KrylovConstructor; window::Integer = 5) v = similar(kc.vn_empty) err_vec = zeros(T, window) stats = SimpleStats(0, false, false, false, 0, T[], T[], T[], 0.0, "unknown") - workspace = LsqrWorkspace{T,FC,S}(m, n, x, Nv, Aᴴu, w, Mu, Av, u, v, err_vec, stats) + workspace = LsqrWorkspace{T,FC,Sm,Sn}(m, n, x, Nv, Aᴴu, w, Mu, Av, u, v, err_vec, stats) return workspace end @@ -2071,7 +2080,7 @@ function LsqrWorkspace(m::Integer, n::Integer, S::Type; window::Integer = 5) err_vec = zeros(T, window) S = isconcretetype(S) ? S : typeof(x) stats = SimpleStats(0, false, false, false, 0, T[], T[], T[], 0.0, "unknown") - workspace = LsqrWorkspace{T,FC,S}(m, n, x, Nv, Aᴴu, w, Mu, Av, u, v, err_vec, stats) + workspace = LsqrWorkspace{T,FC,S,S}(m, n, x, Nv, Aᴴu, w, Mu, Av, u, v, err_vec, stats) return workspace end @@ -2609,21 +2618,21 @@ The following outer constructors can be used to initialize this workspace: `memory` is set to `n + m` if the value given is larger than `n + m`. """ -mutable struct GpmrWorkspace{T,FC,S} <: KrylovWorkspace{T,FC,S} +mutable struct GpmrWorkspace{T,FC,Sm,Sn} <: KrylovWorkspaceNext{T,FC,Sm,Sn} m :: Int n :: Int - wA :: S - wB :: S - dA :: S - dB :: S - Δx :: S - Δy :: S - x :: S - y :: S - q :: S - p :: S - V :: Vector{S} - U :: Vector{S} + wA :: Sn + wB :: Sm + dA :: Sm + dB :: Sn + Δx :: Sm + Δy :: Sn + x :: Sm + y :: Sn + q :: Sm + p :: Sn + V :: Vector{Sm} + U :: Vector{Sn} gs :: Vector{FC} gc :: Vector{T} zt :: Vector{FC} @@ -2632,9 +2641,8 @@ mutable struct GpmrWorkspace{T,FC,S} <: KrylovWorkspace{T,FC,S} stats :: SimpleStats{T} end -function GpmrWorkspace(kc::KrylovConstructor; memory::Integer = 20) - S = typeof(kc.vm) - FC = eltype(S) +function GpmrWorkspace(kc::KrylovConstructor{Sm,Sn}; memory::Integer = 20) where {Sm,Sn} + FC = eltype(Sm) T = real(FC) m = length(kc.vm) n = length(kc.vn) @@ -2649,14 +2657,14 @@ function GpmrWorkspace(kc::KrylovConstructor; memory::Integer = 20) y = similar(kc.vn) q = similar(kc.vm_empty) p = similar(kc.vn_empty) - V = S[similar(kc.vm) for i = 1 : memory] - U = S[similar(kc.vn) for i = 1 : memory] + V = Sm[similar(kc.vm) for i = 1 : memory] + U = Sn[similar(kc.vn) for i = 1 : memory] gs = Vector{FC}(undef, 4 * memory) gc = Vector{T}(undef, 4 * memory) zt = Vector{FC}(undef, 2 * memory) R = Vector{FC}(undef, memory * (2 * memory + 1)) stats = SimpleStats(0, false, false, false, 0, T[], T[], T[], 0.0, "unknown") - workspace = GpmrWorkspace{T,FC,S}(m, n, wA, wB, dA, dB, Δx, Δy, x, y, q, p, V, U, gs, gc, zt, R, false, stats) + workspace = GpmrWorkspace{T,FC,Sm,Sn}(m, n, wA, wB, dA, dB, Δx, Δy, x, y, q, p, V, U, gs, gc, zt, R, false, stats) return workspace end @@ -2682,7 +2690,7 @@ function GpmrWorkspace(m::Integer, n::Integer, S::Type; memory::Integer = 20) R = Vector{FC}(undef, memory * (2 * memory + 1)) S = isconcretetype(S) ? S : typeof(x) stats = SimpleStats(0, false, false, false, 0, T[], T[], T[], 0.0, "unknown") - workspace = GpmrWorkspace{T,FC,S}(m, n, wA, wB, dA, dB, Δx, Δy, x, y, q, p, V, U, gs, gc, zt, R, false, stats) + workspace = GpmrWorkspace{T,FC,S,S}(m, n, wA, wB, dA, dB, Δx, Δy, x, y, q, p, V, U, gs, gc, zt, R, false, stats) return workspace end diff --git a/src/lsqr.jl b/src/lsqr.jl index ade3b6330..5c8bd0894 100644 --- a/src/lsqr.jl +++ b/src/lsqr.jl @@ -164,7 +164,7 @@ args_lsqr = (:A, :b) kwargs_lsqr = (:M, :N, :ldiv, :sqd, :λ, :radius, :etol, :axtol, :btol, :conlim, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream) @eval begin - function lsqr!(workspace :: LsqrWorkspace{T,FC,S}, $(def_args_lsqr...); $(def_kwargs_lsqr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}} + function lsqr!(workspace :: LsqrWorkspace{T,FC,Sm,Sn}, $(def_args_lsqr...); $(def_kwargs_lsqr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, Sm <: AbstractVector{FC}, Sn <: AbstractVector{FC}} # Timer start_time = time_ns() @@ -185,14 +185,14 @@ kwargs_lsqr = (:M, :N, :ldiv, :sqd, :λ, :radius, :etol, :axtol, :btol, :conlim, # Check type consistency eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-vector products." - ktypeof(b) == S || error("ktypeof(b) must be equal to $S") + ktypeof(b) == Sm || error("ktypeof(b) must be equal to $Sm") # Compute the adjoint of A Aᴴ = A' # Set up workspace. - allocate_if(!MisI, workspace, :u, S, workspace.Av) # The length of u is m - allocate_if(!NisI, workspace, :v, S, workspace.x) # The length of v is n + allocate_if(!MisI, workspace, :u, Sm, workspace.Av) # The length of u is m + allocate_if(!NisI, workspace, :v, Sn, workspace.x) # The length of v is n x, Nv, Aᴴu, w = workspace.x, workspace.Nv, workspace.Aᴴu, workspace.w Mu, Av, err_vec, stats = workspace.Mu, workspace.Av, workspace.err_vec, workspace.stats rNorms, ArNorms = stats.residuals, stats.Aresiduals diff --git a/src/tricg.jl b/src/tricg.jl index 74bd12562..e70f3f951 100644 --- a/src/tricg.jl +++ b/src/tricg.jl @@ -146,7 +146,7 @@ optargs_tricg = (:x0, :y0) kwargs_tricg = (:M, :N, :ldiv, :spd, :snd, :flip, :τ, :ν, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream) @eval begin - function tricg!(workspace :: TricgWorkspace{T,FC,S}, $(def_args_tricg...); $(def_kwargs_tricg...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}} + function tricg!(workspace :: TricgWorkspace{T,FC,Sm,Sn}, $(def_args_tricg...); $(def_kwargs_tricg...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, Sm <: AbstractVector{FC}, Sn <: AbstractVector{FC}} # Timer start_time = time_ns() @@ -169,8 +169,8 @@ kwargs_tricg = (:M, :N, :ldiv, :spd, :snd, :flip, :τ, :ν, :atol, :rtol, :itmax # Check type consistency eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-vector products." - ktypeof(b) == S || error("ktypeof(b) must be equal to $S") - ktypeof(c) == S || error("ktypeof(c) must be equal to $S") + ktypeof(b) == Sm || error("ktypeof(b) must be equal to $Sm") + ktypeof(c) == Sn || error("ktypeof(c) must be equal to $Sn") # Determine τ and ν associated to SQD, SPD or SND systems. flip && (τ = -one(T) ; ν = one(T)) @@ -185,8 +185,8 @@ kwargs_tricg = (:M, :N, :ldiv, :spd, :snd, :flip, :τ, :ν, :atol, :rtol, :itmax Aᴴ = A' # Set up workspace. - allocate_if(!MisI, workspace, :vₖ, S, workspace.x) # The length of vₖ is m - allocate_if(!NisI, workspace, :uₖ, S, workspace.y) # The length of uₖ is n + allocate_if(!MisI, workspace, :vₖ, Sm, workspace.x) # The length of vₖ is m + allocate_if(!NisI, workspace, :uₖ, Sn, workspace.y) # The length of uₖ is n Δy, yₖ, N⁻¹uₖ₋₁, N⁻¹uₖ, p = workspace.Δy, workspace.y, workspace.N⁻¹uₖ₋₁, workspace.N⁻¹uₖ, workspace.p Δx, xₖ, M⁻¹vₖ₋₁, M⁻¹vₖ, q = workspace.Δx, workspace.x, workspace.M⁻¹vₖ₋₁, workspace.M⁻¹vₖ, workspace.q gy₂ₖ₋₁, gy₂ₖ, gx₂ₖ₋₁, gx₂ₖ = workspace.gy₂ₖ₋₁, workspace.gy₂ₖ, workspace.gx₂ₖ₋₁, workspace.gx₂ₖ diff --git a/src/trimr.jl b/src/trimr.jl index 5482e6624..e8107f172 100644 --- a/src/trimr.jl +++ b/src/trimr.jl @@ -147,7 +147,7 @@ optargs_trimr = (:x0, :y0) kwargs_trimr = (:M, :N, :ldiv, :spd, :snd, :flip, :sp, :τ, :ν, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream) @eval begin - function trimr!(workspace :: TrimrWorkspace{T,FC,S}, $(def_args_trimr...); $(def_kwargs_trimr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}} + function trimr!(workspace :: TrimrWorkspace{T,FC,Sm,Sn}, $(def_args_trimr...); $(def_kwargs_trimr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, Sm <: AbstractVector{FC}, Sn <: AbstractVector{FC}} # Timer start_time = time_ns() @@ -173,8 +173,8 @@ kwargs_trimr = (:M, :N, :ldiv, :spd, :snd, :flip, :sp, :τ, :ν, :atol, :rtol, : # Check type consistency eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-vector products." - ktypeof(b) == S || error("ktypeof(b) must be equal to $S") - ktypeof(c) == S || error("ktypeof(c) must be equal to $S") + ktypeof(b) == Sm || error("ktypeof(b) must be equal to $Sm") + ktypeof(c) == Sn || error("ktypeof(c) must be equal to $Sn") # Determine τ and ν associated to SQD, SPD or SND systems. flip && (τ = -one(T) ; ν = one(T)) @@ -190,8 +190,8 @@ kwargs_trimr = (:M, :N, :ldiv, :spd, :snd, :flip, :sp, :τ, :ν, :atol, :rtol, : Aᴴ = A' # Set up workspace. - allocate_if(!MisI, workspace, :vₖ, S, workspace.x) # The length of vₖ is m - allocate_if(!NisI, workspace, :uₖ, S, workspace.y) # The length of uₖ is n + allocate_if(!MisI, workspace, :vₖ, Sm, workspace.x) # The length of vₖ is m + allocate_if(!NisI, workspace, :uₖ, Sn, workspace.y) # The length of uₖ is n Δy, yₖ, N⁻¹uₖ₋₁, N⁻¹uₖ, p = workspace.Δy, workspace.y, workspace.N⁻¹uₖ₋₁, workspace.N⁻¹uₖ, workspace.p Δx, xₖ, M⁻¹vₖ₋₁, M⁻¹vₖ, q = workspace.Δx, workspace.x, workspace.M⁻¹vₖ₋₁, workspace.M⁻¹vₖ, workspace.q gy₂ₖ₋₃, gy₂ₖ₋₂, gy₂ₖ₋₁, gy₂ₖ = workspace.gy₂ₖ₋₃, workspace.gy₂ₖ₋₂, workspace.gy₂ₖ₋₁, workspace.gy₂ₖ diff --git a/test/test_cgls.jl b/test/test_cgls.jl index 6e314c04e..27e72dcfd 100644 --- a/test/test_cgls.jl +++ b/test/test_cgls.jl @@ -47,6 +47,15 @@ (x, stats) = cgls(A, b, M=D⁻¹, λ=1.0) end + # Test different types for input and output + A, b, c, D = small_sp(false, FC=FC) + workspace = CglsWorkspace(KrylovConstructor(TestVector(b), c)) + cgls!(workspace, A, TestVector(b), M=inv(D), λ=1.0) + @test typeof(workspace.x) === typeof(c) + workspace = CglsWorkspace(KrylovConstructor(b, TestVector(c))) + cgls!(workspace, A, b, M=inv(D), λ=1.0) + @test typeof(workspace.x) === typeof(TestVector(c)) + # test callback function A, b, M = saddle_point(FC=FC) M⁻¹ = inv(M) diff --git a/test/test_gpmr.jl b/test/test_gpmr.jl index 0363fc7ff..2082df5b5 100644 --- a/test/test_gpmr.jl +++ b/test/test_gpmr.jl @@ -301,8 +301,20 @@ (x, y, stats) = gpmr(A, A', b, c, C=M⁻¹, D=N⁻¹) end + # Test different types for b and c + A, b, _ = saddle_point(FC=FC) + (x, y, stats) = gpmr(A, A', b, -b) + c = TestVector(-b) + workspace = GpmrWorkspace(KrylovConstructor(b, c)) + @test typeof(workspace.x) === typeof(b) + @test typeof(workspace.y) === typeof(c) + gpmr!(workspace, A, A', b, c) + xt, yt = workspace.x, workspace.y + @test typeof(xt) === typeof(b) + @test typeof(yt) === typeof(c) + # test callback function - # Not testing with an interesting callback because workspace.x and workspace.y are not updated + # Not testing with an interesting callback because workspace.x and workspace.y are not updated # until the end of the algorithm (TODO: be able to evaluate workspace.x and workspace.y ?) A, b, c = square_adjoint(FC=FC) workspace = GpmrWorkspace(A, b; memory = 20) diff --git a/test/test_interface.jl b/test/test_interface.jl index 91f16deea..11ac901fc 100644 --- a/test/test_interface.jl +++ b/test/test_interface.jl @@ -296,6 +296,14 @@ function test_krylov_workspaces(FC; krylov_constructor::Bool=false, use_val::Boo @test solution(workspace, 1) === workspace.x @test solution(workspace, 2) === workspace.y @test solution_count(workspace) == 2 + + if method == :gpmr + x, y, stats = use_val ? @inferred(krylov_solve(Val(method), Ao, Au, TestVector(b), c)) : krylov_solve(method, Ao, Au, TestVector(b), c) + @test typeof(x) === typeof(TestVector(b)) + else + x, y, stats = use_val ? @inferred(krylov_solve(Val(method), Au, TestVector(c), b)) : krylov_solve(method, Au, TestVector(c), b) + @test typeof(x) === typeof(TestVector(c)) + end end if method ∈ (:usymlq, :usymqr) diff --git a/test/test_lsqr.jl b/test/test_lsqr.jl index 8b70cec57..13a9937bf 100644 --- a/test/test_lsqr.jl +++ b/test/test_lsqr.jl @@ -88,6 +88,15 @@ (x, stats) = lsqr(A, b, M=M⁻¹, N=N⁻¹, sqd=true) end + # Test different types for input and output + A, b, c, D = small_sp(false, FC=FC) + workspace = LsqrWorkspace(KrylovConstructor(TestVector(b), c)) + lsqr!(workspace, A, TestVector(b), M=inv(D), λ=1.0) + @test typeof(workspace.x) === typeof(c) + workspace = LsqrWorkspace(KrylovConstructor(b, TestVector(c))) + lsqr!(workspace, A, b, M=inv(D), λ=1.0) + @test typeof(workspace.x) === typeof(TestVector(c)) + # test callback function A, b, M = saddle_point(FC=FC) M⁻¹ = inv(M) diff --git a/test/test_tricg.jl b/test/test_tricg.jl index 7b17b5e06..73178390a 100644 --- a/test/test_tricg.jl +++ b/test/test_tricg.jl @@ -166,6 +166,18 @@ @test sqrt(dot(r, inv(H) * r)) / sqrt(dot([b; c], inv(H) * [b; c])) ≤ tricg_tol end + # Test different types for b and c + A, b, _ = saddle_point(FC=FC) + (x, y, stats) = tricg(A, b, -b) + c = TestVector(-b) + workspace = TricgWorkspace(KrylovConstructor(b, c)) + @test typeof(workspace.x) === typeof(b) + @test typeof(workspace.y) === typeof(c) + tricg!(workspace, A, b, c) + xt, yt = workspace.x, workspace.y + @test typeof(xt) === typeof(b) + @test typeof(yt) === typeof(c) + for transpose ∈ (false, true) A, b, c = ssy_mo_breakdown(transpose) diff --git a/test/test_trimr.jl b/test/test_trimr.jl index eddcc69fc..d8b160143 100644 --- a/test/test_trimr.jl +++ b/test/test_trimr.jl @@ -206,6 +206,18 @@ @test sqrt(dot(r, inv(H) * r)) / sqrt(dot([b; c], inv(H) * [b; c])) ≤ trimr_tol end + # Test different types for b and c + A, b, _ = saddle_point(FC=FC) + (x, y, stats) = trimr(A, b, -b) + c = TestVector(-b) + workspace = TrimrWorkspace(KrylovConstructor(b, c)) + @test typeof(workspace.x) === typeof(b) + @test typeof(workspace.y) === typeof(c) + trimr!(workspace, A, b, c) + xt, yt = workspace.x, workspace.y + @test typeof(xt) === typeof(b) + @test typeof(yt) === typeof(c) + for transpose ∈ (false, true) A, b, c = ssy_mo_breakdown(transpose) diff --git a/test/test_utils.jl b/test/test_utils.jl index 4d71545f2..2ce77d71d 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -5,6 +5,15 @@ include("gen_lsq.jl") include("check_min_norm.jl") include("callback_utils.jl") +struct TestVector{T} <: AbstractVector{T} + data::Vector{T} +end +TestVector{T}(::UndefInitializer, n::Integer) where {T} = TestVector{T}(Vector{T}(undef, n)) +Base.size(v::TestVector) = size(v.data) +Base.getindex(v::TestVector, i::Int) = v.data[i] +Base.setindex!(v::TestVector, val, i::Int) = (v.data[i] = val) +Base.similar(v::TestVector, ::Type{S}, dims::Dims) where {S} = TestVector{S}(similar(v.data, S, dims)) + # Symmetric and positive definite systems. function symmetric_definite(n :: Int=10; FC=Float64) α = FC <: Complex ? FC(im) : one(FC)