diff --git a/docs/src/getting_started.md b/docs/src/getting_started.md index ece51177..effba4d2 100644 --- a/docs/src/getting_started.md +++ b/docs/src/getting_started.md @@ -49,7 +49,8 @@ Most solvers contain the `log` keyword. This is to be used when obtaining more information is required, to use it place the set `log` to `true`. ```julia -x, ch = cg(Master, rand(10, 10), rand(10) log=true) +r = cg(Master, rand(10, 10), rand(10); log=true) +x, ch = r.x, r.history svd, L, ch = svdl(Master, rand(100, 100), log=true) ``` diff --git a/docs/src/linear_systems/cg.md b/docs/src/linear_systems/cg.md index 7fde0a13..f32a716b 100644 --- a/docs/src/linear_systems/cg.md +++ b/docs/src/linear_systems/cg.md @@ -20,7 +20,8 @@ n = 100 A = cu(rand(n, n)) A = A + A' + 2*n*I b = cu(rand(n)) -x = cg(A, b) +r = cg(A, b) +x = r.x ``` !!! note diff --git a/src/cg.jl b/src/cg.jl index edfa86e8..d0977ca4 100644 --- a/src/cg.jl +++ b/src/cg.jl @@ -8,7 +8,7 @@ mutable struct CGIterable{matT, solT, vecT, numT <: Real} r::vecT c::vecT u::vecT - reltol::numT + tol::numT residual::numT prev_residual::numT maxiter::Int @@ -22,18 +22,49 @@ mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Numb r::vecT c::vecT u::vecT - reltol::numT + tol::numT residual::numT ρ::paramT maxiter::Int mv_products::Int end -@inline converged(it::Union{CGIterable, PCGIterable}) = it.residual ≤ it.reltol +struct CGResult{Tx, T, Thistory} + x::Tx + residual::T + tol::T + iterations::Int + maxiter::Int + converged::Bool + history::Thistory +end +function Base.show(io::IO, r::CGResult) + first_two(fr) = [x for (i, x) in enumerate(fr)][1:2] + + @printf io "Result of CG Algorithm\n" + @printf io " * Algorithm: CG \n" + + if length(join(r.x, ",")) < 40 || length(r.x) <= 2 + @printf io " * x: [%s]\n" join(r.x, ",") + else + @printf io " * x: [%s, ...]\n" join(first_two(r.x), ",") + end + + @printf io " * Convergence\n" + @printf io " * Residual: %s\n" r.residual + @printf io " * Tolerance: %s\n" r.tol 
+ @printf io " * Converged: %s\n" r.converged + @printf io " * Iterations: %s\n" r.iterations + @printf io " * Iterations limit: %s\n" r.maxiter + + return +end + +@inline isconverged(it::Union{CGIterable, PCGIterable}) = it.residual ≤ it.tol @inline start(it::Union{CGIterable, PCGIterable}) = 0 -@inline done(it::Union{CGIterable, PCGIterable}, iteration::Int) = iteration ≥ it.maxiter || converged(it) +@inline done(it::Union{CGIterable, PCGIterable}, iteration::Int) = iteration ≥ it.maxiter || isconverged(it) ############### @@ -114,7 +145,8 @@ struct CGStateVariables{T,Tx<:AbstractArray{T}} end function cg_iterator!(x, A, b, Pl = Identity(); - tol = sqrt(eps(real(eltype(b)))), + reltol = sqrt(eps(real(eltype(b)))), + tol = zero(real(eltype(b))), maxiter::Int = size(A, 2), statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)), initially_zero::Bool = false @@ -130,24 +162,24 @@ function cg_iterator!(x, A, b, Pl = Identity(); mv_products = 0 c = similar(x) residual = norm(b) - reltol = residual * tol # Save one dot product + tol = max(residual * reltol, tol) # Save one dot product else mv_products = 1 mul!(c, A, x) r .-= c residual = norm(r) - reltol = norm(b) * tol + tol = max(norm(b) * reltol, tol) end # Return the iterable if isa(Pl, Identity) return CGIterable(A, x, r, c, u, - reltol, residual, one(residual), + tol, residual, one(residual), maxiter, mv_products ) else return PCGIterable(Pl, A, x, r, c, u, - reltol, residual, one(eltype(x)), + tol, residual, one(eltype(x)), maxiter, mv_products ) end @@ -177,7 +209,8 @@ cg(A, b; kwargs...) = cg!(zerox(A, b), A, b; initially_zero = true, kwargs...) residual vector; - `Pl = Identity()`: left preconditioner of the method. 
Should be symmetric, positive-definite like `A`; -- `tol::Real = sqrt(eps(real(eltype(b))))`: tolerance for stopping condition `|r_k| / |r_0| ≤ tol`; +- `reltol::Real = sqrt(eps(real(eltype(b))))`: relative tolerance for stopping condition `|r_k| / |r_0| ≤ reltol`; +- `tol::Real = zero(real(eltype(b)))`: absolute tolerance for stopping condition `|r_k| ≤ tol`; - `maxiter::Int = size(A,2)`: maximum number of iterations; - `verbose::Bool = false`: print method information; - `log::Bool = false`: keep track of the residual norm in each iteration. @@ -199,7 +232,8 @@ cg(A, b; kwargs...) = cg!(zerox(A, b), A, b; initially_zero = true, kwargs...) - `:resnom` => `::Vector`: residual norm at each iteration. """ function cg!(x, A, b; - tol = sqrt(eps(real(eltype(b)))), + reltol = sqrt(eps(real(eltype(b)))), + tol = zero(real(eltype(b))), maxiter::Int = size(A, 2), log::Bool = false, statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)), @@ -208,15 +242,17 @@ function cg!(x, A, b; kwargs... ) history = ConvergenceHistory(partial = !log) - history[:tol] = tol log && reserve!(history, :resnorm, maxiter + 1) # Actually perform CG - iterable = cg_iterator!(x, A, b, Pl; tol = tol, maxiter = maxiter, statevars = statevars, kwargs...) + iterable = cg_iterator!(x, A, b, Pl; tol = tol, reltol = reltol, maxiter = maxiter, statevars = statevars, kwargs...) + history[:tol] = iterable.tol if log history.mvps = iterable.mv_products end - for (iteration, item) = enumerate(iterable) + iteration = 0 + for item in iterable + iteration += 1 if log nextiter!(history, mvps = 1) push!(history, :resnorm, iterable.residual) @@ -225,8 +261,9 @@ function cg!(x, A, b; end verbose && println() - log && setconv(history, converged(iterable)) + converged = isconverged(iterable) + log && setconv(history, converged) log && shrink!(history) - log ? 
(iterable.x, history) : iterable.x + return CGResult(iterable.x, iterable.residual, iterable.tol, iteration, maxiter, converged, history) end diff --git a/test/cg.jl b/test/cg.jl index 4ef1b148..67e2ccad 100644 --- a/test/cg.jl +++ b/test/cg.jl @@ -26,26 +26,31 @@ Random.seed!(1234321) A = rand(T, n, n) A = A' * A + I b = rand(T, n) - tol = √eps(real(T)) + reltol = √eps(real(T)) - x,ch = cg(A, b; tol=tol, maxiter=2n, log=true) + r = cg(A, b; reltol=reltol, maxiter=2n, log=true) + x, ch = r.x, r.history @test isa(ch, ConvergenceHistory) - @test norm(A*x - b) / norm(b) ≤ tol + @test norm(A*x - b) / norm(b) ≤ reltol + @test norm(A*x - b) ≤ r.tol @test ch.isconverged # If you start from the exact solution, you should converge immediately - x,ch = cg!(A \ b, A, b; tol=10tol, log=true) + r = cg!(A \ b, A, b; reltol=10reltol, log=true) + x, ch = r.x, r.history @test niters(ch) ≤ 1 @test nprods(ch) ≤ 2 # Test with cholfact should converge immediately F = cholesky(A, Val(false)) - x,ch = cg(A, b; Pl=F, log=true) + r = cg(A, b; Pl=F, log=true) + x, ch = r.x, r.history @test niters(ch) ≤ 2 @test nprods(ch) ≤ 2 # All-zeros rhs should give all-zeros lhs - x0 = cg(A, zeros(T, n)) + r = cg(A, zeros(T, n)) + x0 = r.x @test x0 == zeros(T, n) end end @@ -59,24 +64,30 @@ end tol = 1e-5 @testset "SparseMatrixCSC{$T, $Ti}" for T in (Float64, Float32), Ti in (Int64, Int32) - xCG = cg(A, rhs; tol=tol, maxiter=100) - xJAC = cg(A, rhs; Pl=P, tol=tol, maxiter=100) + r = cg(A, rhs; tol=tol, maxiter=100) + xCG = r.x + r = cg(A, rhs; Pl=P, tol=tol, maxiter=100) + xJAC = r.x @test norm(A * xCG - rhs) ≤ tol @test norm(A * xJAC - rhs) ≤ tol end Af = LinearMap(A) @testset "Function" begin - xCG = cg(Af, rhs; tol=tol, maxiter=100) - xJAC = cg(Af, rhs; Pl=P, tol=tol, maxiter=100) + r = cg(Af, rhs; tol=tol, maxiter=100) + xCG = r.x + r = cg(Af, rhs; Pl=P, tol=tol, maxiter=100) + xJAC = r.x @test norm(A * xCG - rhs) ≤ tol @test norm(A * xJAC - rhs) ≤ tol end @testset "Function with specified starting 
guess" begin x0 = randn(size(rhs)) - xCG, hCG = cg!(copy(x0), Af, rhs; tol=tol, maxiter=100, log=true) - xJAC, hJAC = cg!(copy(x0), Af, rhs; Pl=P, tol=tol, maxiter=100, log=true) + r = cg!(copy(x0), Af, rhs; tol=tol, maxiter=100, log=true) + xCG, hCG = r.x, r.history + r = cg!(copy(x0), Af, rhs; Pl=P, tol=tol, maxiter=100, log=true) + xJAC, hJAC = r.x, r.history @test norm(A * xCG - rhs) ≤ tol @test norm(A * xJAC - rhs) ≤ tol @test niters(hJAC) == niters(hCG) @@ -88,7 +99,8 @@ end A = A + A' + 100I x = view(rand(10, 2), :, 1) b = rand(10) - x, hist = cg!(x, A, b, log = true) + r = cg!(x, A, b, log = true) + x, hist = r.x, r.history @test hist.isconverged end