diff --git a/src/cg.jl b/src/cg.jl index 5c38c5ea..4fd6574a 100644 --- a/src/cg.jl +++ b/src/cg.jl @@ -48,15 +48,15 @@ function iterate(it::CGIterable, iteration::Int=start(it)) # u := r + βu (almost an axpy) β = it.residual^2 / it.prev_residual^2 - it.u .= it.r .+ β .* it.u + axpby!(true, it.r, β, it.u) # c = A * u mul!(it.c, it.A, it.u) α = it.residual^2 / dot(it.u, it.c) # Improve solution and residual - it.x .+= α .* it.u - it.r .-= α .* it.c + axpy!(α, it.u, it.x) + axpy!(-α, it.c, it.r) it.prev_residual = it.residual it.residual = norm(it.r) @@ -83,15 +83,15 @@ function iterate(it::PCGIterable, iteration::Int=start(it)) # u := c + βu (almost an axpy) β = it.ρ / ρ_prev - it.u .= it.c .+ β .* it.u + axpby!(true, it.c, β, it.u) # c = A * u mul!(it.c, it.A, it.u) α = it.ρ / dot(it.u, it.c) # Improve solution and residual - it.x .+= α .* it.u - it.r .-= α .* it.c + axpy!(α, it.u, it.x) + axpy!(-α, it.c, it.r) it.residual = norm(it.r) @@ -135,7 +135,7 @@ function cg_iterator!(x, A, b, Pl = Identity(); else mv_products = 1 mul!(c, A, x) - r .-= c + axpy!(-one(eltype(c)), c, r) end residual = norm(r) tolerance = max(reltol * residual, abstol)