|
"""
    GaussNewton(; concrete_jac = nothing, linsolve = NormalCholeskyFactorization(),
        precs = DEFAULT_PRECS, adkwargs...)

An advanced Gauss-Newton implementation with support for efficient handling of sparse
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
for large-scale and numerically-difficult nonlinear least squares problems.

!!! note
    In most practical situations, users should prefer using `LevenbergMarquardt` instead! It
    is a more general extension of the Gauss-Newton method.

### Keyword Arguments

  - `autodiff`: determines the backend used for the Jacobian. Note that this argument is
    ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to
    `AutoForwardDiff()`. Valid choices are types from ADTypes.jl.
  - `concrete_jac`: whether to build a concrete Jacobian. If a Krylov-subspace method is used,
    then the Jacobian will not be constructed and instead direct Jacobian-vector products
    `J*v` are computed using forward-mode automatic differentiation or finite differencing
    tricks (without ever constructing the Jacobian). However, if the Jacobian is still needed,
    for example for a preconditioner, `concrete_jac = true` can be passed in order to force
    the construction of the Jacobian.
  - `linsolve`: the [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl) algorithm used
    for the linear solves within the Gauss-Newton method. Defaults to
    `NormalCholeskyFactorization()`. Passing `nothing` defers to the LinearSolve.jl default
    algorithm choice. For more information on available algorithm choices, see the
    [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
  - `precs`: the choice of preconditioners for the linear solver. Defaults to using no
    preconditioners. For more information on specifying preconditioners for LinearSolve
    algorithms, consult the
    [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).

!!! warning

    The Jacobian-free version of `GaussNewton` doesn't work yet, and it forces Jacobian
    construction. This will be fixed in the near future.
"""
@concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
    ad::AD
    linsolve
    precs
end
| 43 | + |
# Keyword constructor for `GaussNewton`.
#
# `concrete_jac` is unwrapped with `_unwrap_val` and becomes the first type parameter;
# any remaining keyword arguments are forwarded to `default_adargs_to_adtype` to pick
# the automatic-differentiation backend stored in the `ad` field.
function GaussNewton(; concrete_jac = nothing, linsolve = NormalCholeskyFactorization(),
    precs = DEFAULT_PRECS, adkwargs...)
    backend = default_adargs_to_adtype(; adkwargs...)
    CJ = _unwrap_val(concrete_jac)
    return GaussNewton{CJ}(backend, linsolve, precs)
end
| 49 | + |
# Mutable iteration state for the `GaussNewton` solver. The type parameter `iip` selects
# between the in-place (`true`) and out-of-place (`false`) `perform_step!` methods.
# Fields are filled positionally by `SciMLBase.__init`, so their order must not change.
@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
    f  # residual function taken from the problem
    alg  # the `GaussNewton` algorithm instance (provides `precs`)
    u  # current iterate
    fu1  # residual evaluated at the current iterate
    fu2  # buffer returned by `jacobian_caches`; presumably Jacobian scratch space — TODO confirm
    fu_new  # residual evaluated at the freshly updated iterate
    du  # step obtained from the linear solve; subtracted from `u`
    p  # problem parameters
    uf  # returned by `jacobian_caches`; presumably the wrapped function used for AD — TODO confirm
    linsolve  # linear-solver cache; replaced by `linres.cache` after every solve
    J  # Jacobian of the residual
    JᵀJ  # normal-equation matrix `J' * J`
    Jᵀf  # normal-equation right-hand side `J' * fu1`
    jac_cache  # Jacobian workspace returned by `jacobian_caches`
    force_stop  # set to `true` by `perform_step!` once a convergence test passes
    maxiters::Int  # iteration budget
    internalnorm  # norm used in the convergence tests
    retcode::ReturnCode.T  # solver status; starts as `ReturnCode.Default`
    abstol  # tolerance for the convergence tests; also passed as `reltol` to the linear solve
    prob  # the originating problem
    stats::NLStats  # evaluation / Jacobian / factorization / solve / step counters
end
| 73 | + |
# Build the `GaussNewtonCache` for a nonlinear least squares problem.
#
# Keyword arguments:
#   - `alias_u0`: when `true` the cache iterates directly on `prob.u0`; otherwise a deep
#     copy is taken so the caller's initial guess is left untouched.
#   - `maxiters`: iteration budget stored in the cache.
#   - `abstol`: tolerance used by the convergence tests in `perform_step!`.
#   - `internalnorm`: norm used by those convergence tests.
#
# The residual is evaluated once at the initial iterate, which is why the statistics start
# with a leading `1` in `NLStats(1, 0, 0, 0, 0)`.
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg::GaussNewton,
    args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
    kwargs...) where {uType, iip}
    @unpack f, u0, p = prob
    u = alias_u0 ? u0 : deepcopy(u0)
    if iip
        # Allocate a private residual buffer. `f` writes into it on every evaluation, so
        # reusing the user-supplied `resid_prototype` itself would overwrite the caller's
        # array and make every cache built from the same function share one buffer.
        fu1 = f.resid_prototype === nothing ? zero(u) : similar(f.resid_prototype)
        f(fu1, u, p)
    else
        fu1 = f(u, p)
    end
    uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip))

    # Storage for the normal equations (JᵀJ) du = Jᵀf: JᵀJ is square in the number of
    # Jacobian columns, Jᵀf has the shape of `u`.
    JᵀJ = J isa Number ? zero(J) : similar(J, size(J, 2), size(J, 2))
    Jᵀf = zero(u)

    return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J,
        JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
        prob, NLStats(1, 0, 0, 0, 0))
end
| 94 | + |
# One in-place Gauss-Newton iteration.
#
# Refreshes the Jacobian, assembles the normal equations (JᵀJ) du = Jᵀf, solves them for
# `du`, updates `u .-= du`, re-evaluates the residual and sets `cache.force_stop` when
# either the residual change or the residual itself falls below `cache.abstol`.
# Returns `nothing`; all results are written into `cache`.
function perform_step!(cache::GaussNewtonCache{true})
    @unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
    jacobian!!(J, cache)
    mul!(JᵀJ, J', J)
    # Flatten the buffers exactly as the linear solve below does, so that array-shaped
    # (non-vector) states and residuals do not trip a dimension mismatch in `mul!`.
    mul!(_vec(Jᵀf), J', _vec(fu1))

    # Solve (JᵀJ) du = Jᵀf, then take the step u = u - du.
    linres = dolinsolve(alg.precs, linsolve; A = JᵀJ, b = _vec(Jᵀf), linu = _vec(du),
        p, reltol = cache.abstol)
    cache.linsolve = linres.cache
    @. u = u - du
    f(cache.fu_new, u, p)

    # Stop when the residual stagnates or is already small enough.
    (cache.internalnorm(cache.fu_new .- cache.fu1) < cache.abstol ||
     cache.internalnorm(cache.fu_new) < cache.abstol) &&
        (cache.force_stop = true)
    cache.fu1 .= cache.fu_new
    cache.stats.nf += 1
    cache.stats.njacs += 1
    cache.stats.nsolve += 1
    cache.stats.nfactors += 1
    return nothing
end
| 118 | + |
# One out-of-place Gauss-Newton iteration.
#
# Nothing is mutated in place on the iterate: the Jacobian, the normal-equation terms,
# the new iterate and the new residual are all rebound on `cache`. Returns `nothing`.
function perform_step!(cache::GaussNewtonCache{false})
    @unpack u, fu1, f, p, alg, linsolve = cache

    cache.J = jacobian!!(cache.J, cache)
    Jac = cache.J
    cache.JᵀJ = Jac' * Jac
    cache.Jᵀf = Jac' * fu1

    if linsolve !== nothing
        # Solve the normal equations (JᵀJ) du = Jᵀf through the linear-solver cache.
        sol = dolinsolve(alg.precs, linsolve; A = cache.JᵀJ, b = _vec(cache.Jᵀf),
            linu = _vec(cache.du), p = p, reltol = cache.abstol)
        cache.linsolve = sol.cache
    else
        # No linear-solver cache available: fall back to a direct division.
        cache.du = fu1 / Jac
    end

    # Rebind rather than mutate: `u` might not support mutation.
    cache.u = u .- cache.du
    cache.fu_new = f(cache.u, p)

    # Converged when the residual stagnates or is already below tolerance.
    resid_change = cache.internalnorm(cache.fu_new .- cache.fu1)
    if resid_change < cache.abstol || cache.internalnorm(cache.fu_new) < cache.abstol
        cache.force_stop = true
    end
    cache.fu1 = cache.fu_new

    stats = cache.stats
    stats.nf += 1
    stats.njacs += 1
    stats.nsolve += 1
    stats.nfactors += 1
    return nothing
end
| 146 | + |
# Reset `cache` so it can be reused for a new solve starting from `u0`.
#
# Arguments:
#   - `u0`: new initial iterate (defaults to the current `cache.u`).
#   - `p`, `abstol`, `maxiters`: optional replacements for the stored values.
#
# The residual is re-evaluated at the new iterate and the stop flag, return code and
# statistics are restored to their freshly-initialized state. Returns `cache`.
function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache.p,
    abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
    cache.p = p
    if iip
        recursivecopy!(cache.u, u0)
        cache.f(cache.fu1, cache.u, p)
    else
        # don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
        cache.u = u0
        cache.fu1 = cache.f(cache.u, p)
    end
    cache.abstol = abstol
    cache.maxiters = maxiters
    # Match what a fresh `__init` produces (`NLStats(1, 0, 0, 0, 0)`): one residual
    # evaluation has just been performed and every other counter starts from zero, so
    # counts from the previous solve do not leak into the next one.
    cache.stats.nf = 1
    cache.stats.njacs = 0
    cache.stats.nfactors = 0
    cache.stats.nsolve = 0
    cache.stats.nsteps = 0
    cache.force_stop = false
    cache.retcode = ReturnCode.Default
    return cache
end
0 commit comments