Skip to content

Commit 98c6c66

Browse files
committed
Support TRICG, TRIMR, GPMR
Also add tests
1 parent 1379a43 commit 98c6c66

File tree

12 files changed

+183
-87
lines changed

12 files changed

+183
-87
lines changed

src/gpmr.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ optargs_gpmr = (:x0, :y0)
157157
kwargs_gpmr = (:C, :D, :E, :F, :ldiv, :gsp, , , :reorthogonalization, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)
158158

159159
@eval begin
160-
function gpmr!(workspace :: GpmrWorkspace{T,FC,S}, $(def_args_gpmr...); $(def_kwargs_gpmr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}
160+
function gpmr!(workspace :: GpmrWorkspace{T,FC,Sm,Sn}, $(def_args_gpmr...); $(def_kwargs_gpmr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, Sm <: AbstractVector{FC}, Sn <: AbstractVector{FC}}
161161

162162
# Timer
163163
start_time = time_ns()
@@ -181,8 +181,8 @@ kwargs_gpmr = (:C, :D, :E, :F, :ldiv, :gsp, :λ, :μ, :reorthogonalization, :ato
181181
# Check type consistency
182182
eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-vector products."
183183
eltype(B) == FC || @warn "eltype(B) ≠ $FC. This could lead to errors or additional allocations in operator-vector products."
184-
ktypeof(b) == S || error("ktypeof(b) must be equal to $S")
185-
ktypeof(c) == S || error("ktypeof(c) must be equal to $S")
184+
ktypeof(b) == Sm || error("ktypeof(b) must be equal to $Sm")
185+
ktypeof(c) == Sn || error("ktypeof(c) must be equal to $Sn")
186186

187187
# Determine λ and μ associated to generalized saddle point systems.
188188
gsp &&= one(FC) ; μ = zero(FC))
@@ -192,10 +192,10 @@ kwargs_gpmr = (:C, :D, :E, :F, :ldiv, :gsp, :λ, :μ, :reorthogonalization, :ato
192192
warm_start && 0) && !FisI && error("Warm-start with right preconditioners is not supported.")
193193

194194
# Set up workspace.
195-
allocate_if(!CisI, workspace, :q , S, workspace.x) # The length of q is m
196-
allocate_if(!DisI, workspace, :p , S, workspace.y) # The length of p is n
197-
allocate_if(!EisI, workspace, :wB, S, workspace.x) # The length of wB is m
198-
allocate_if(!FisI, workspace, :wA, S, workspace.y) # The length of wA is n
195+
allocate_if(!CisI, workspace, :q , Sm, workspace.x) # The length of q is m
196+
allocate_if(!DisI, workspace, :p , Sn, workspace.y) # The length of p is n
197+
allocate_if(!EisI, workspace, :wB, Sm, workspace.x) # The length of wB is m
198+
allocate_if(!FisI, workspace, :wA, Sn, workspace.y) # The length of wA is n
199199
wA, wB, dA, dB, Δx, Δy = workspace.wA, workspace.wB, workspace.dA, workspace.dB, workspace.Δx, workspace.Δy
200200
x, y, V, U, gs, gc = workspace.x, workspace.y, workspace.V, workspace.U, workspace.gs, workspace.gc
201201
zt, R, stats = workspace.zt, workspace.R, workspace.stats

src/interface.jl

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ for (workspace, krylov, args, def_args, optargs, def_optargs, kwargs, def_kwargs
138138

139139
krylov_solve(::Val{Symbol($krylov)}, $(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} = $(krylov)($(args...); $(kwargs...))
140140
end
141-
elseif krylov in (:diom, :dqgmres, :fom, :gmres, :fgmres, :gpmr)
141+
elseif krylov in (:diom, :dqgmres, :fom, :gmres, :fgmres)
142142
@eval begin
143143
function $(krylov)($(def_args...); memory::Int = 20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
144144
start_time = time_ns()
@@ -196,6 +196,35 @@ for (workspace, krylov, args, def_args, optargs, def_optargs, kwargs, def_kwargs
196196
krylov_solve(::Val{Symbol($krylov)}, $(def_args...), $(def_optargs...); window::Int = 5, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} = $(krylov)($(args...), $(optargs...); window, $(kwargs...))
197197
end
198198
end
199+
elseif krylov in (:tricg, :trimr, :gpmr)
200+
@eval begin
201+
function $(krylov)($(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
202+
start_time = time_ns()
203+
workspace = $workspace(KrylovConstructor(b, c))
204+
elapsed_time = start_time |> ktimer
205+
timemax -= elapsed_time
206+
$(krylov!)(workspace, $(args...); $(kwargs...))
207+
workspace.stats.timer += elapsed_time
208+
return results(workspace)
209+
end
210+
211+
krylov_solve(::Val{Symbol($krylov)}, $(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} = $(krylov)($(args...); $(kwargs...))
212+
213+
if !isempty($optargs)
214+
function $(krylov)($(def_args...), $(def_optargs...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
215+
start_time = time_ns()
216+
workspace = $workspace(KrylovConstructor(b, c))
217+
warm_start!(workspace, $(optargs...))
218+
elapsed_time = start_time |> ktimer
219+
timemax -= elapsed_time
220+
$(krylov!)(workspace, $(args...); $(kwargs...))
221+
workspace.stats.timer += elapsed_time
222+
return results(workspace)
223+
end
224+
225+
krylov_solve(::Val{Symbol($krylov)}, $(def_args...), $(def_optargs...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} = $(krylov)($(args...), $(optargs...); $(kwargs...))
226+
end
227+
end
199228
else
200229
@eval begin
201230
function $(krylov)($(def_args...); $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}

src/krylov_workspaces.jl

Lines changed: 65 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -988,32 +988,31 @@ The following outer constructors can be used to initialize this workspace:
988988
workspace = TricgWorkspace(A, b)
989989
workspace = TricgWorkspace(kc::KrylovConstructor)
990990
"""
991-
mutable struct TricgWorkspace{T,FC,S} <: KrylovWorkspace{T,FC,S}
991+
mutable struct TricgWorkspace{T,FC,Sm,Sn} <: KrylovWorkspaceNext{T,FC,Sm,Sn}
992992
m :: Int
993993
n :: Int
994-
y :: S
995-
N⁻¹uₖ₋₁ :: S
996-
N⁻¹uₖ :: S
997-
p :: S
998-
gy₂ₖ₋₁ :: S
999-
gy₂ₖ :: S
1000-
x :: S
1001-
M⁻¹vₖ₋₁ :: S
1002-
M⁻¹vₖ :: S
1003-
q :: S
1004-
gx₂ₖ₋₁ :: S
1005-
gx₂ₖ :: S
1006-
Δx :: S
1007-
Δy :: S
1008-
uₖ :: S
1009-
vₖ :: S
994+
y :: Sn
995+
N⁻¹uₖ₋₁ :: Sn
996+
N⁻¹uₖ :: Sn
997+
p :: Sn
998+
gy₂ₖ₋₁ :: Sn
999+
gy₂ₖ :: Sn
1000+
x :: Sm
1001+
M⁻¹vₖ₋₁ :: Sm
1002+
M⁻¹vₖ :: Sm
1003+
q :: Sm
1004+
gx₂ₖ₋₁ :: Sm
1005+
gx₂ₖ :: Sm
1006+
Δx :: Sm
1007+
Δy :: Sn
1008+
uₖ :: Sn
1009+
vₖ :: Sm
10101010
warm_start :: Bool
10111011
stats :: SimpleStats{T}
10121012
end
10131013

1014-
function TricgWorkspace(kc::KrylovConstructor)
1015-
S = typeof(kc.vm)
1016-
FC = eltype(S)
1014+
function TricgWorkspace(kc::KrylovConstructor{Sm,Sn}) where {Sm,Sn}
1015+
FC = eltype(Sm)
10171016
T = real(FC)
10181017
m = length(kc.vm)
10191018
n = length(kc.vn)
@@ -1034,7 +1033,7 @@ function TricgWorkspace(kc::KrylovConstructor)
10341033
uₖ = similar(kc.vn_empty)
10351034
vₖ = similar(kc.vm_empty)
10361035
stats = SimpleStats(0, false, false, false, 0, T[], T[], T[], 0.0, "unknown")
1037-
workspace = TricgWorkspace{T,FC,S}(m, n, y, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₁, gy₂ₖ, x, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₁, gx₂ₖ, Δx, Δy, uₖ, vₖ, false, stats)
1036+
workspace = TricgWorkspace{T,FC,Sm,Sn}(m, n, y, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₁, gy₂ₖ, x, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₁, gx₂ₖ, Δx, Δy, uₖ, vₖ, false, stats)
10381037
return workspace
10391038
end
10401039

@@ -1059,7 +1058,7 @@ function TricgWorkspace(m::Integer, n::Integer, S::Type)
10591058
vₖ = S(undef, 0)
10601059
S = isconcretetype(S) ? S : typeof(x)
10611060
stats = SimpleStats(0, false, false, false, 0, T[], T[], T[], 0.0, "unknown")
1062-
workspace = TricgWorkspace{T,FC,S}(m, n, y, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₁, gy₂ₖ, x, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₁, gx₂ₖ, Δx, Δy, uₖ, vₖ, false, stats)
1061+
workspace = TricgWorkspace{T,FC,S,S}(m, n, y, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₁, gy₂ₖ, x, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₁, gx₂ₖ, Δx, Δy, uₖ, vₖ, false, stats)
10631062
return workspace
10641063
end
10651064

@@ -1078,36 +1077,35 @@ The following outer constructors can be used to initialize this workspace:
10781077
workspace = TrimrWorkspace(A, b)
10791078
workspace = TrimrWorkspace(kc::KrylovConstructor)
10801079
"""
1081-
mutable struct TrimrWorkspace{T,FC,S} <: KrylovWorkspace{T,FC,S}
1080+
mutable struct TrimrWorkspace{T,FC,Sm,Sn} <: KrylovWorkspaceNext{T,FC,Sm,Sn}
10821081
m :: Int
10831082
n :: Int
1084-
y :: S
1085-
N⁻¹uₖ₋₁ :: S
1086-
N⁻¹uₖ :: S
1087-
p :: S
1088-
gy₂ₖ₋₃ :: S
1089-
gy₂ₖ₋₂ :: S
1090-
gy₂ₖ₋₁ :: S
1091-
gy₂ₖ :: S
1092-
x :: S
1093-
M⁻¹vₖ₋₁ :: S
1094-
M⁻¹vₖ :: S
1095-
q :: S
1096-
gx₂ₖ₋₃ :: S
1097-
gx₂ₖ₋₂ :: S
1098-
gx₂ₖ₋₁ :: S
1099-
gx₂ₖ :: S
1100-
Δx :: S
1101-
Δy :: S
1102-
uₖ :: S
1103-
vₖ :: S
1083+
y :: Sn
1084+
N⁻¹uₖ₋₁ :: Sn
1085+
N⁻¹uₖ :: Sn
1086+
p :: Sn
1087+
gy₂ₖ₋₃ :: Sn
1088+
gy₂ₖ₋₂ :: Sn
1089+
gy₂ₖ₋₁ :: Sn
1090+
gy₂ₖ :: Sn
1091+
x :: Sm
1092+
M⁻¹vₖ₋₁ :: Sm
1093+
M⁻¹vₖ :: Sm
1094+
q :: Sm
1095+
gx₂ₖ₋₃ :: Sm
1096+
gx₂ₖ₋₂ :: Sm
1097+
gx₂ₖ₋₁ :: Sm
1098+
gx₂ₖ :: Sm
1099+
Δx :: Sm
1100+
Δy :: Sn
1101+
uₖ :: Sn
1102+
vₖ :: Sm
11041103
warm_start :: Bool
11051104
stats :: SimpleStats{T}
11061105
end
11071106

1108-
function TrimrWorkspace(kc::KrylovConstructor)
1109-
S = typeof(kc.vm)
1110-
FC = eltype(S)
1107+
function TrimrWorkspace(kc::KrylovConstructor{Sm,Sn}) where {Sm,Sn}
1108+
FC = eltype(Sm)
11111109
T = real(FC)
11121110
m = length(kc.vm)
11131111
n = length(kc.vn)
@@ -1132,7 +1130,7 @@ function TrimrWorkspace(kc::KrylovConstructor)
11321130
uₖ = similar(kc.vn_empty)
11331131
vₖ = similar(kc.vm_empty)
11341132
stats = SimpleStats(0, false, false, false, 0, T[], T[], T[], 0.0, "unknown")
1135-
workspace = TrimrWorkspace{T,FC,S}(m, n, y, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₃, gy₂ₖ₋₂, gy₂ₖ₋₁, gy₂ₖ, x, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₃, gx₂ₖ₋₂, gx₂ₖ₋₁, gx₂ₖ, Δx, Δy, uₖ, vₖ, false, stats)
1133+
workspace = TrimrWorkspace{T,FC,Sm,Sn}(m, n, y, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₃, gy₂ₖ₋₂, gy₂ₖ₋₁, gy₂ₖ, x, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₃, gx₂ₖ₋₂, gx₂ₖ₋₁, gx₂ₖ, Δx, Δy, uₖ, vₖ, false, stats)
11361134
return workspace
11371135
end
11381136

@@ -1161,7 +1159,7 @@ function TrimrWorkspace(m::Integer, n::Integer, S::Type)
11611159
vₖ = S(undef, 0)
11621160
S = isconcretetype(S) ? S : typeof(x)
11631161
stats = SimpleStats(0, false, false, false, 0, T[], T[], T[], 0.0, "unknown")
1164-
workspace = TrimrWorkspace{T,FC,S}(m, n, y, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₃, gy₂ₖ₋₂, gy₂ₖ₋₁, gy₂ₖ, x, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₃, gx₂ₖ₋₂, gx₂ₖ₋₁, gx₂ₖ, Δx, Δy, uₖ, vₖ, false, stats)
1162+
workspace = TrimrWorkspace{T,FC,S,S}(m, n, y, N⁻¹uₖ₋₁, N⁻¹uₖ, p, gy₂ₖ₋₃, gy₂ₖ₋₂, gy₂ₖ₋₁, gy₂ₖ, x, M⁻¹vₖ₋₁, M⁻¹vₖ, q, gx₂ₖ₋₃, gx₂ₖ₋₂, gx₂ₖ₋₁, gx₂ₖ, Δx, Δy, uₖ, vₖ, false, stats)
11651163
return workspace
11661164
end
11671165

@@ -2621,21 +2619,21 @@ The following outer constructors can be used to initialize this workspace:
26212619
26222620
`memory` is set to `n + m` if the value given is larger than `n + m`.
26232621
"""
2624-
mutable struct GpmrWorkspace{T,FC,S} <: KrylovWorkspace{T,FC,S}
2622+
mutable struct GpmrWorkspace{T,FC,Sm,Sn} <: KrylovWorkspaceNext{T,FC,Sm,Sn}
26252623
m :: Int
26262624
n :: Int
2627-
wA :: S
2628-
wB :: S
2629-
dA :: S
2630-
dB :: S
2631-
Δx :: S
2632-
Δy :: S
2633-
x :: S
2634-
y :: S
2635-
q :: S
2636-
p :: S
2637-
V :: Vector{S}
2638-
U :: Vector{S}
2625+
wA :: Sn
2626+
wB :: Sm
2627+
dA :: Sm
2628+
dB :: Sn
2629+
Δx :: Sm
2630+
Δy :: Sn
2631+
x :: Sm
2632+
y :: Sn
2633+
q :: Sm
2634+
p :: Sn
2635+
V :: Vector{Sm}
2636+
U :: Vector{Sn}
26392637
gs :: Vector{FC}
26402638
gc :: Vector{T}
26412639
zt :: Vector{FC}
@@ -2644,9 +2642,8 @@ mutable struct GpmrWorkspace{T,FC,S} <: KrylovWorkspace{T,FC,S}
26442642
stats :: SimpleStats{T}
26452643
end
26462644

2647-
function GpmrWorkspace(kc::KrylovConstructor; memory::Integer = 20)
2648-
S = typeof(kc.vm)
2649-
FC = eltype(S)
2645+
function GpmrWorkspace(kc::KrylovConstructor{Sm,Sn}; memory::Integer = 20) where {Sm,Sn}
2646+
FC = eltype(Sm)
26502647
T = real(FC)
26512648
m = length(kc.vm)
26522649
n = length(kc.vn)
@@ -2661,14 +2658,14 @@ function GpmrWorkspace(kc::KrylovConstructor; memory::Integer = 20)
26612658
y = similar(kc.vn)
26622659
q = similar(kc.vm_empty)
26632660
p = similar(kc.vn_empty)
2664-
V = S[similar(kc.vm) for i = 1 : memory]
2665-
U = S[similar(kc.vn) for i = 1 : memory]
2661+
V = Sm[similar(kc.vm) for i = 1 : memory]
2662+
U = Sn[similar(kc.vn) for i = 1 : memory]
26662663
gs = Vector{FC}(undef, 4 * memory)
26672664
gc = Vector{T}(undef, 4 * memory)
26682665
zt = Vector{FC}(undef, 2 * memory)
26692666
R = Vector{FC}(undef, memory * (2 * memory + 1))
26702667
stats = SimpleStats(0, false, false, false, 0, T[], T[], T[], 0.0, "unknown")
2671-
workspace = GpmrWorkspace{T,FC,S}(m, n, wA, wB, dA, dB, Δx, Δy, x, y, q, p, V, U, gs, gc, zt, R, false, stats)
2668+
workspace = GpmrWorkspace{T,FC,Sm,Sn}(m, n, wA, wB, dA, dB, Δx, Δy, x, y, q, p, V, U, gs, gc, zt, R, false, stats)
26722669
return workspace
26732670
end
26742671

@@ -2694,7 +2691,7 @@ function GpmrWorkspace(m::Integer, n::Integer, S::Type; memory::Integer = 20)
26942691
R = Vector{FC}(undef, memory * (2 * memory + 1))
26952692
S = isconcretetype(S) ? S : typeof(x)
26962693
stats = SimpleStats(0, false, false, false, 0, T[], T[], T[], 0.0, "unknown")
2697-
workspace = GpmrWorkspace{T,FC,S}(m, n, wA, wB, dA, dB, Δx, Δy, x, y, q, p, V, U, gs, gc, zt, R, false, stats)
2694+
workspace = GpmrWorkspace{T,FC,S,S}(m, n, wA, wB, dA, dB, Δx, Δy, x, y, q, p, V, U, gs, gc, zt, R, false, stats)
26982695
return workspace
26992696
end
27002697

src/tricg.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ optargs_tricg = (:x0, :y0)
142142
kwargs_tricg = (:M, :N, :ldiv, :spd, :snd, :flip, , , :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)
143143

144144
@eval begin
145-
function tricg!(workspace :: TricgWorkspace{T,FC,S}, $(def_args_tricg...); $(def_kwargs_tricg...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}
145+
function tricg!(workspace :: TricgWorkspace{T,FC,Sm,Sn}, $(def_args_tricg...); $(def_kwargs_tricg...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, Sm <: AbstractVector{FC}, Sn <: AbstractVector{FC}}
146146

147147
# Timer
148148
start_time = time_ns()
@@ -165,8 +165,8 @@ kwargs_tricg = (:M, :N, :ldiv, :spd, :snd, :flip, :τ, :ν, :atol, :rtol, :itmax
165165

166166
# Check type consistency
167167
eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-vector products."
168-
ktypeof(b) == S || error("ktypeof(b) must be equal to $S")
169-
ktypeof(c) == S || error("ktypeof(c) must be equal to $S")
168+
ktypeof(b) == Sm || error("ktypeof(b) must be equal to $Sm")
169+
ktypeof(c) == Sn || error("ktypeof(c) must be equal to $Sn")
170170

171171
# Determine τ and ν associated to SQD, SPD or SND systems.
172172
flip &&= -one(T) ; ν = one(T))
@@ -181,8 +181,8 @@ kwargs_tricg = (:M, :N, :ldiv, :spd, :snd, :flip, :τ, :ν, :atol, :rtol, :itmax
181181
Aᴴ = A'
182182

183183
# Set up workspace.
184-
allocate_if(!MisI, workspace, :vₖ, S, workspace.x) # The length of vₖ is m
185-
allocate_if(!NisI, workspace, :uₖ, S, workspace.y) # The length of uₖ is n
184+
allocate_if(!MisI, workspace, :vₖ, Sm, workspace.x) # The length of vₖ is m
185+
allocate_if(!NisI, workspace, :uₖ, Sn, workspace.y) # The length of uₖ is n
186186
Δy, yₖ, N⁻¹uₖ₋₁, N⁻¹uₖ, p = workspace.Δy, workspace.y, workspace.N⁻¹uₖ₋₁, workspace.N⁻¹uₖ, workspace.p
187187
Δx, xₖ, M⁻¹vₖ₋₁, M⁻¹vₖ, q = workspace.Δx, workspace.x, workspace.M⁻¹vₖ₋₁, workspace.M⁻¹vₖ, workspace.q
188188
gy₂ₖ₋₁, gy₂ₖ, gx₂ₖ₋₁, gx₂ₖ = workspace.gy₂ₖ₋₁, workspace.gy₂ₖ, workspace.gx₂ₖ₋₁, workspace.gx₂ₖ

src/trimr.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ optargs_trimr = (:x0, :y0)
143143
kwargs_trimr = (:M, :N, :ldiv, :spd, :snd, :flip, :sp, , , :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)
144144

145145
@eval begin
146-
function trimr!(workspace :: TrimrWorkspace{T,FC,S}, $(def_args_trimr...); $(def_kwargs_trimr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}}
146+
function trimr!(workspace :: TrimrWorkspace{T,FC,Sm,Sn}, $(def_args_trimr...); $(def_kwargs_trimr...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, Sm <: AbstractVector{FC}, Sn <: AbstractVector{FC}}
147147

148148
# Timer
149149
start_time = time_ns()
@@ -169,8 +169,8 @@ kwargs_trimr = (:M, :N, :ldiv, :spd, :snd, :flip, :sp, :τ, :ν, :atol, :rtol, :
169169

170170
# Check type consistency
171171
eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-vector products."
172-
ktypeof(b) == S || error("ktypeof(b) must be equal to $S")
173-
ktypeof(c) == S || error("ktypeof(c) must be equal to $S")
172+
ktypeof(b) == Sm || error("ktypeof(b) must be equal to $Sm")
173+
ktypeof(c) == Sn || error("ktypeof(c) must be equal to $Sn")
174174

175175
# Determine τ and ν associated to SQD, SPD or SND systems.
176176
flip &&= -one(T) ; ν = one(T))
@@ -186,8 +186,8 @@ kwargs_trimr = (:M, :N, :ldiv, :spd, :snd, :flip, :sp, :τ, :ν, :atol, :rtol, :
186186
Aᴴ = A'
187187

188188
# Set up workspace.
189-
allocate_if(!MisI, workspace, :vₖ, S, workspace.x) # The length of vₖ is m
190-
allocate_if(!NisI, workspace, :uₖ, S, workspace.y) # The length of uₖ is n
189+
allocate_if(!MisI, workspace, :vₖ, Sm, workspace.x) # The length of vₖ is m
190+
allocate_if(!NisI, workspace, :uₖ, Sn, workspace.y) # The length of uₖ is n
191191
Δy, yₖ, N⁻¹uₖ₋₁, N⁻¹uₖ, p = workspace.Δy, workspace.y, workspace.N⁻¹uₖ₋₁, workspace.N⁻¹uₖ, workspace.p
192192
Δx, xₖ, M⁻¹vₖ₋₁, M⁻¹vₖ, q = workspace.Δx, workspace.x, workspace.M⁻¹vₖ₋₁, workspace.M⁻¹vₖ, workspace.q
193193
gy₂ₖ₋₃, gy₂ₖ₋₂, gy₂ₖ₋₁, gy₂ₖ = workspace.gy₂ₖ₋₃, workspace.gy₂ₖ₋₂, workspace.gy₂ₖ₋₁, workspace.gy₂ₖ

test/test_cgls.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@
4747
(x, stats) = cgls(A, b, M=D⁻¹, λ=1.0)
4848
end
4949

50+
# Test different types for input and output
51+
A, b, c, D = small_sp(false, FC=FC)
52+
workspace = CglsWorkspace(KrylovConstructor(TestVector(b), c))
53+
cgls!(workspace, A, TestVector(b), M=inv(D), λ=1.0)
54+
@test typeof(workspace.x) === typeof(c)
55+
workspace = CglsWorkspace(KrylovConstructor(b, TestVector(c)))
56+
cgls!(workspace, A, b, M=inv(D), λ=1.0)
57+
@test typeof(workspace.x) === typeof(TestVector(c))
58+
5059
# test callback function
5160
A, b, M = saddle_point(FC=FC)
5261
M⁻¹ = inv(M)

0 commit comments

Comments
 (0)