Skip to content

Commit b0c28f1

Browse files
authored
Merge pull request #308 from avik-pal/ap/minor_patches
2 parents cbf0861 + 41f4fec commit b0c28f1

File tree

7 files changed

+65
-44
lines changed

7 files changed

+65
-44
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.0.0"
4+
version = "3.0.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/basics/NonlinearSolution.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,5 @@ SciMLBase.NonlinearSolution
1414
`NonlinearSafeTerminationReturnCode.ProtectiveTermination` and is caused if the step-size
1515
of the solver was too large or the objective value became non-finite.
1616
- `ReturnCode.MaxIters` - The maximum number of iterations was reached.
17+
- `ReturnCode.Failure` - The nonlinear solve failed for some reason. This is used
18+
sparingly and mostly for wrapped solvers for which we don't have a better error code.

ext/NonlinearSolveMINPACKExt.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ using MINPACK
66
function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
77
NonlinearLeastSquaresProblem{uType, iip}}, alg::CMINPACK, args...;
88
abstol = 1e-6, maxiters = 100000, alias_u0::Bool = false,
9-
kwargs...) where {uType, iip}
9+
termination_condition = nothing, kwargs...) where {uType, iip}
10+
@assert termination_condition===nothing "CMINPACK does not support termination conditions!"
11+
1012
if prob.u0 isa Number
1113
u0 = [prob.u0]
1214
else
@@ -64,7 +66,11 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
6466

6567
u = reshape(original.x, size(u))
6668
resid = original.f
67-
retcode = original.converged ? ReturnCode.Success : ReturnCode.Failure
69+
# retcode = original.converged ? ReturnCode.Success : ReturnCode.Failure
70+
# MINPACK lies about convergence? or maybe uses some other criteria?
71+
# We just check for absolute tolerance on the residual
72+
objective = NonlinearSolve.DEFAULT_NORM(resid)
73+
retcode = ifelse(objective ≤ abstol, ReturnCode.Success, ReturnCode.Failure)
6874

6975
return SciMLBase.build_solution(prob, alg, u, resid; retcode, original)
7076
end

ext/NonlinearSolveNLsolveExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ using NonlinearSolve, NLsolve, DiffEqBase, SciMLBase
44
import UnPack: @unpack
55

66
function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abstol = 1e-6,
7-
maxiters = 1000, alias_u0::Bool = false, kwargs...)
7+
maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing, kwargs...)
8+
@assert termination_condition===nothing "NLsolveJL does not support termination conditions!"
9+
810
if typeof(prob.u0) <: Number
911
u0 = [prob.u0]
1012
else

src/levenberg.jl

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,24 @@ An advanced Levenberg-Marquardt implementation with the improvements suggested i
1010
algorithm for nonlinear least-squares minimization". Designed for large-scale and
1111
numerically-difficult nonlinear systems.
1212
13-
If no `linsolve` is provided or a variant of `QR` is provided, then we will use an efficient
14-
routine for the factorization without constructing `JᵀJ` and `Jᵀf`. For more details see
15-
"Chapter 10: Implementation of the Levenberg-Marquardt Method" of
16-
["Numerical Optimization" by Jorge Nocedal & Stephen J. Wright](https://link.springer.com/book/10.1007/978-0-387-40065-5).
13+
### How to Choose the Linear Solver?
14+
15+
There are 2 ways to perform the LM Step
16+
17+
1. Solve `(JᵀJ + λDᵀD) δx = Jᵀf` directly using a linear solver
18+
2. Solve for `Jδx = f` and `√λ⋅D δx = 0` simultaneously (to derive this simply compute the
19+
normal form for this)
20+
21+
The second form tends to be more robust and can be solved using any Least Squares Solver.
22+
If `linsolve` is unset or is a least squares solver, then we will solve the 2nd form.
23+
However, in most cases, this means losing structure in `J` which is not ideal. Note that
24+
whatever you do, do not specify solvers like `linsolve = NormalCholeskyFactorization()` or
25+
any such solver which converts the equation to normal form before solving. These don't use
26+
cache efficiently and we already support the normal form natively.
27+
28+
Additionally, note that the first form leads to a positive definite system, so we can use
29+
more efficient solvers like `linsolve = CholeskyFactorization()`. If you know that the
30+
problem is very well conditioned, then you might want to solve the normal form directly.
1731
1832
### Keyword Arguments
1933
@@ -168,7 +182,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
168182
T = eltype(u)
169183
fu = evaluate_f(prob, u)
170184

171-
fastls = !__needs_square_A(alg, u0)
185+
fastls = prob isa NonlinearProblem && !__needs_square_A(alg, u0)
172186

173187
if !fastls
174188
uf, linsolve, J, fu_cache, jac_cache, du, JᵀJ, v = jacobian_caches(alg, f, u, p,
@@ -253,9 +267,9 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
253267
if fastls
254268
if setindex_trait(cache.mat_tmp) === CanSetindex()
255269
copyto!(@view(cache.mat_tmp[1:length(cache.fu), :]), cache.J)
256-
cache.mat_tmp[(length(cache.fu) + 1):end, :] .= cache.λ .* cache.DᵀD
270+
cache.mat_tmp[(length(cache.fu) + 1):end, :] .= sqrt.(cache.λ .* cache.DᵀD)
257271
else
258-
cache.mat_tmp = _vcat(cache.J, cache.λ .* cache.DᵀD)
272+
cache.mat_tmp = _vcat(cache.J, sqrt.(cache.λ .* cache.DᵀD))
259273
end
260274
if setindex_trait(cache.rhs_tmp) === CanSetindex()
261275
cache.rhs_tmp[1:length(cache.fu)] .= _vec(cache.fu)
@@ -283,7 +297,7 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
283297
evaluate_f(cache, cache.u_cache_2, cache.p, Val(:fu_cache_2))
284298

285299
# The following lines do: cache.a = -cache.mat_tmp \ cache.fu_tmp
286-
# NOTE: Don't pass `A`` in again, since we want to reuse the previous solve
300+
# NOTE: Don't pass `A` in again, since we want to reuse the previous solve
287301
@bb cache.Jv = cache.J × vec(cache.v)
288302
Jv = _restructure(cache.fu_cache_2, cache.Jv)
289303
@bb @. cache.fu_cache_2 = (2 / cache.h) * ((cache.fu_cache_2 - cache.fu) / cache.h - Jv)
@@ -337,6 +351,33 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
337351
return nothing
338352
end
339353

354+
@inline __update_LM_diagonal!!(y::Number, x::Number) = max(y, x)
355+
@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractVector)
356+
if setindex_trait(y.diag) === CanSetindex()
357+
@. y.diag = max(y.diag, x)
358+
return y
359+
else
360+
return Diagonal(max.(y.diag, x))
361+
end
362+
end
363+
@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractMatrix)
364+
if setindex_trait(y.diag) === CanSetindex()
365+
if fast_scalar_indexing(y.diag)
366+
@inbounds for i in axes(x, 1)
367+
y.diag[i] = max(y.diag[i], x[i, i])
368+
end
369+
return y
370+
else
371+
idxs = diagind(x)
372+
@.. broadcast=false y.diag=max(y.diag, @view(x[idxs]))
373+
return y
374+
end
375+
else
376+
idxs = diagind(x)
377+
return Diagonal(@.. broadcast=false max(y.diag, @view(x[idxs])))
378+
end
379+
end
380+
340381
function __reinit_internal!(cache::LevenbergMarquardtCache;
341382
termination_condition = get_termination_mode(cache.tc_cache_1), kwargs...)
342383
abstol, reltol, tc_cache_1 = init_termination_cache(cache.abstol, cache.reltol,

src/utils.jl

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -442,33 +442,6 @@ function __sum_JᵀJ!!(y, J)
442442
end
443443
end
444444

445-
@inline __update_LM_diagonal!!(y::Number, x::Number) = max(y, x)
446-
@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractVector)
447-
if setindex_trait(y.diag) === CanSetindex()
448-
@. y.diag = max(y.diag, x)
449-
return y
450-
else
451-
return Diagonal(max.(y.diag, x))
452-
end
453-
end
454-
@inline function __update_LM_diagonal!!(y::Diagonal, x::AbstractMatrix)
455-
if setindex_trait(y.diag) === CanSetindex()
456-
if fast_scalar_indexing(y.diag)
457-
@inbounds for i in axes(x, 1)
458-
y.diag[i] = max(y.diag[i], x[i, i])
459-
end
460-
return y
461-
else
462-
idxs = diagind(x)
463-
@.. broadcast=false y.diag=max(y.diag, @view(x[idxs]))
464-
return y
465-
end
466-
else
467-
idxs = diagind(x)
468-
return Diagonal(@.. broadcast=false max(y.diag, @view(x[idxs])))
469-
end
470-
end
471-
472445
# Alpha for Initial Jacobian Guess
473446
# The values are somewhat different from SciPy, these were tuned to the 23 test problems
474447
@inline function __initial_inv_alpha(inv_alpha::Number, u, fu, norm::F) where {F}

test/23_test_problems.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ end
3939
@testset "NewtonRaphson 23 Test Problems" begin
4040
alg_ops = (NewtonRaphson(),)
4141

42-
# dictionary with indices of test problems where method does not converge to small residual
4342
broken_tests = Dict(alg => Int[] for alg in alg_ops)
4443
broken_tests[alg_ops[1]] = [1, 6]
4544

@@ -54,7 +53,6 @@ end
5453
TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.Bastin),
5554
TrustRegion(; radius_update_scheme = RadiusUpdateSchemes.NLsolve))
5655

57-
# dictionary with indices of test problems where method does not converge to small residual
5856
broken_tests = Dict(alg => Int[] for alg in alg_ops)
5957
broken_tests[alg_ops[1]] = [6, 11, 21]
6058
broken_tests[alg_ops[2]] = [6, 11, 21]
@@ -70,10 +68,9 @@ end
7068
alg_ops = (LevenbergMarquardt(), LevenbergMarquardt(; α_geodesic = 0.1),
7169
LevenbergMarquardt(; linsolve = CholeskyFactorization()))
7270

73-
# dictionary with indices of test problems where method does not converge to small residual
7471
broken_tests = Dict(alg => Int[] for alg in alg_ops)
75-
broken_tests[alg_ops[1]] = [3, 6, 11, 17, 21]
76-
broken_tests[alg_ops[2]] = [3, 6, 11, 17, 21]
72+
broken_tests[alg_ops[1]] = [6, 11, 21]
73+
broken_tests[alg_ops[2]] = [6, 11, 21]
7774
broken_tests[alg_ops[3]] = [6, 11, 21]
7875

7976
test_on_library(problems, dicts, alg_ops, broken_tests)

0 commit comments

Comments
 (0)