Skip to content

Commit 843465c

Browse files
authored
Update the API for workspaces (#975)
* Update the API for workspaces * Fix the tests on GPUs * More update for the solvers * Use S::Type
1 parent a3bbee9 commit 843465c

File tree

9 files changed

+147
-155
lines changed

9 files changed

+147
-155
lines changed

docs/src/inplace.md

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ XyzSolver(kc::KrylovConstructor)
1515
If the name of the Krylov method contains an underscore (e.g., `minres_qlp` or `cgls_lanczos_shift`), the workspace constructor transforms it by capitalizing each word and removing underscores, resulting in names like `MinresQlpSolver` or `CglsLanczosShiftSolver`.
1616

1717
Given an operator `A` and a right-hand side `b`, you can create a `KrylovSolver` based on the size of `A` and the type of `b`, or explicitly provide the dimensions `(m, n)` and the storage type `S`.
18+
19+
!!! note
20+
The constructors of `CgLanczosShiftSolver` and `CglsLanczosShiftSolver` require an additional argument `nshifts`.
21+
1822
We assume that `S(undef, 0)`, `S(undef, n)`, and `S(undef, m)` are well-defined for the storage type `S`.
1923
For more advanced vector types, workspaces can also be created with the help of a `KrylovConstructor`.
2024
```@docs
@@ -24,20 +28,20 @@ See the section [custom workspaces](@ref custom_workspaces) for an example where
2428

2529
For example, use `S = Vector{Float64}` if you want to solve linear systems in double precision on the CPU and `S = CuVector{Float32}` if you want to solve linear systems in single precision on an Nvidia GPU.
2630

27-
!!! note
28-
`DiomSolver`, `FomSolver`, `DqgmresSolver`, `GmresSolver`, `BlockGmresSolver`, `FgmresSolver`, `GpmrSolver`, `CgLanczosShiftSolver` and `CglsLanczosShiftSolver` require an additional argument (`memory` or `nshifts`).
29-
3031
The workspace is always the first argument of the in-place methods:
3132

3233
```@solvers
33-
minres_solver = MinresSolver(n, n, Vector{Float64})
34+
minres_solver = MinresSolver(m, n, Vector{Float64})
3435
minres!(minres_solver, A1, b1)
3536
36-
dqgmres_solver = DqgmresSolver(n, n, memory, Vector{BigFloat})
37-
dqgmres!(dqgmres_solver, A2, b2)
37+
bicgstab_solver = BicgstabSolver(m, n, Vector{ComplexF64})
38+
bicgstab!(bicgstab_solver, A2, b2)
39+
40+
gmres_solver = GmresSolver(m, n, Vector{BigFloat})
41+
gmres!(gmres_solver, A3, b3)
3842
3943
lsqr_solver = LsqrSolver(m, n, CuVector{Float32})
40-
lsqr!(lsqr_solver, A3, b3)
44+
lsqr!(lsqr_solver, A4, b4)
4145
```
4246

4347
A generic function `solve!` is also available and dispatches to the appropriate Krylov method.
@@ -46,6 +50,9 @@ A generic function `solve!` is also available and dispatches to the appropriate
4650
Krylov.solve!
4751
```
4852

53+
!!! note
54+
The function `solve!` is not exported to prevent potential conflicts with other Julia packages.
55+
4956
In-place methods return an updated `solver` workspace.
5057
Solutions and statistics can be recovered via `solver.x`, `solver.y` and `solver.stats`.
5158
Functions `solution`, `statistics` and `results` can be also used.

src/block_krylov_processes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ function nonhermitian_lanczos(A, B::AbstractMatrix{FC}, C::AbstractMatrix{FC}, k
119119
m, n = size(A)
120120
t, p = size(B)
121121
Aᴴ = A'
122-
pivoting = VERSION < v"1.9" ? Val{false}() : NoPivot()
122+
pivoting = NoPivot()
123123

124124
nnzT = p*p + (k-1)*p*(2*p+1) + div(p*(p+1), 2)
125125
colptr = zeros(Int, p*k+1)

src/block_krylov_solvers.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ The outer constructors
1717
solver = BlockMinresSolver(A, B)
1818
1919
may be used in order to create these vectors.
20-
`memory` is set to `div(n,p)` if the value given is larger than `div(n,p)`.
2120
"""
2221
mutable struct BlockMinresSolver{T,FC,SV,SM} <: BlockKrylovSolver{T,FC,SV,SM}
2322
m :: Int
@@ -42,7 +41,7 @@ mutable struct BlockMinresSolver{T,FC,SV,SM} <: BlockKrylovSolver{T,FC,SV,SM}
4241
stats :: SimpleStats{T}
4342
end
4443

45-
function BlockMinresSolver(m, n, p, SV, SM)
44+
function BlockMinresSolver(m::Integer, n::Integer, p::Integer, SV::Type, SM::Type)
4645
FC = eltype(SV)
4746
T = real(FC)
4847
ΔX = SM(undef, 0, 0)
@@ -78,12 +77,11 @@ Type for storing the vectors required by the in-place version of BLOCK-GMRES.
7877
7978
The outer constructors
8079
81-
solver = BlockGmresSolver(m, n, p, memory, SV, SM)
82-
solver = BlockGmresSolver(A, B, memory = 5)
80+
solver = BlockGmresSolver(m, n, p, SV, SM; memory = 5)
81+
solver = BlockGmresSolver(A, B; memory = 5)
8382
8483
may be used in order to create these vectors.
8584
`memory` is set to `div(n,p)` if the value given is larger than `div(n,p)`.
86-
`memory` is an optional argument in the second constructor.
8785
"""
8886
mutable struct BlockGmresSolver{T,FC,SV,SM} <: BlockKrylovSolver{T,FC,SV,SM}
8987
m :: Int
@@ -105,7 +103,7 @@ mutable struct BlockGmresSolver{T,FC,SV,SM} <: BlockKrylovSolver{T,FC,SV,SM}
105103
stats :: SimpleStats{T}
106104
end
107105

108-
function BlockGmresSolver(m, n, p, memory, SV, SM)
106+
function BlockGmresSolver(m::Integer, n::Integer, p::Integer, SV::Type, SM::Type; memory::Integer = 5)
109107
memory = min(div(n,p), memory)
110108
FC = eltype(SV)
111109
T = real(FC)
@@ -126,12 +124,12 @@ function BlockGmresSolver(m, n, p, memory, SV, SM)
126124
return solver
127125
end
128126

129-
function BlockGmresSolver(A, B, memory = 5)
127+
function BlockGmresSolver(A, B; memory::Integer = 5)
130128
m, n = size(A)
131129
s, p = size(B)
132130
SM = typeof(B)
133131
SV = matrix_to_vector(SM)
134-
BlockGmresSolver(m, n, p, memory, SV, SM)
132+
BlockGmresSolver(m, n, p, SV, SM; memory)
135133
end
136134

137135
for (KS, fun, nsol, nA, nAt, warm_start) in [
@@ -220,11 +218,7 @@ function show(io :: IO, solver :: Union{KrylovSolver{T,FC,S}, BlockKrylovSolver{
220218
type_i = fieldtype(workspace, i)
221219
field_i = getfield(solver, name_i)
222220
size_i = ksizeof(field_i)
223-
if (name_i::Symbol in [:w̅, :w̄, :d̅]) && (VERSION < v"1.8.0-DEV")
224-
(size_i 0) && Printf.format(io, format2, string(name_i), type_i, format_bytes(size_i))
225-
else
226-
(size_i 0) && Printf.format(io, format, string(name_i), type_i, format_bytes(size_i))
227-
end
221+
(size_i 0) && Printf.format(io, format, string(name_i), type_i, format_bytes(size_i))
228222
end
229223
@printf(io, "└%s┴%s┴%s┘\n",""^l1,""^l2,""^l3)
230224
if show_stats

src/krylov_solve.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
solve!(solver, args...; kwargs...)
1919
2020
Generic function that dispatches to the appropriate in-place Krylov method based on the type of `solver`.
21-
This function is not exported to prevent potential conflicts with other Julia packages.
2221
"""
2322
function solve! end
2423

@@ -82,7 +81,7 @@ for (workspace, krylov, args, def_args, optargs, def_optargs, kwargs, def_kwargs
8281
@eval begin
8382
function $(krylov)($(def_args...); memory::Int=20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
8483
start_time = time_ns()
85-
solver = $workspace(A, b, memory)
84+
solver = $workspace(A, b; memory)
8685
elapsed_time = start_time |> ktimer
8786
timemax -= elapsed_time
8887
$(krylov!)(solver, $(args...); $(kwargs...))
@@ -93,7 +92,7 @@ for (workspace, krylov, args, def_args, optargs, def_optargs, kwargs, def_kwargs
9392
if !isempty($optargs)
9493
function $(krylov)($(def_args...), $(def_optargs...); memory::Int=20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
9594
start_time = time_ns()
96-
solver = $workspace(A, b, memory)
95+
solver = $workspace(A, b; memory)
9796
warm_start!(solver, $(optargs...))
9897
elapsed_time = start_time |> ktimer
9998
timemax -= elapsed_time
@@ -188,7 +187,7 @@ for (workspace, krylov, args, def_args, optargs, def_optargs, kwargs, def_kwargs
188187
@eval begin
189188
function $(krylov)($(def_args...); memory :: Int=20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
190189
start_time = time_ns()
191-
solver = $workspace(A, B, memory)
190+
solver = $workspace(A, B; memory)
192191
elapsed_time = ktimer(start_time)
193192
timemax -= elapsed_time
194193
$(krylov!)(solver, $(args...); $(kwargs...))
@@ -199,7 +198,7 @@ for (workspace, krylov, args, def_args, optargs, def_optargs, kwargs, def_kwargs
199198
if !isempty($optargs)
200199
function $(krylov)($(def_args...), $(def_optargs...); memory :: Int=20, $(def_kwargs...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
201200
start_time = time_ns()
202-
solver = $workspace(A, B, memory)
201+
solver = $workspace(A, B; memory)
203202
warm_start!(solver, $(optargs...))
204203
elapsed_time = ktimer(start_time)
205204
timemax -= elapsed_time

0 commit comments

Comments
 (0)