Skip to content

Commit 5b51678

Browse files
committed
feat: add SimpleHalley method
1 parent 87fad10 commit 5b51678

File tree

3 files changed

+109
-6
lines changed

3 files changed

+109
-6
lines changed

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ using CommonSolve: CommonSolve, solve
44
using ConcreteStructs: @concrete
55
using FastClosures: @closure
66
using LineSearch: LiFukushimaLineSearch
7-
using LinearAlgebra: dot
8-
using MaybeInplace: @bb
7+
using LinearAlgebra: LinearAlgebra, dot
8+
using MaybeInplace: @bb, setindex_trait, CannotSetindex, CanSetindex
99
using PrecompileTools: @compile_workload, @setup_workload
1010
using Reexport: @reexport
1111
@reexport using SciMLBase # I don't like this but needed to avoid a breaking change
@@ -82,18 +82,15 @@ function solve_adjoint_internal end
8282
algs = [
8383
SimpleBroyden(),
8484
SimpleKlement(),
85+
SimpleHalley(),
8586
SimpleNewtonRaphson(),
8687
SimpleTrustRegion()
8788
]
88-
algs_no_iip = []
8989

9090
@compile_workload begin
9191
for alg in algs, prob in (prob_scalar, prob_iip, prob_oop)
9292
CommonSolve.solve(prob, alg)
9393
end
94-
for alg in algs_no_iip
95-
CommonSolve.solve(prob_scalar, alg)
96-
end
9794
end
9895
end
9996
end
@@ -104,5 +101,6 @@ export Alefeld, Bisection, Brent, Falsi, ITP, Ridder
104101

105102
export SimpleBroyden, SimpleKlement
106103
export SimpleGaussNewton, SimpleNewtonRaphson, SimpleTrustRegion
104+
export SimpleHalley
107105

108106
end
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,84 @@
1+
"""
2+
SimpleHalley(autodiff)
3+
SimpleHalley(; autodiff = nothing)
14
5+
A low-overhead implementation of Halley's Method.
6+
7+
!!! note
8+
9+
As part of the decreased overhead, this method omits some of the higher level error
10+
catching of the other methods. Thus, to see better error messages, use one of the other
11+
methods like `NewtonRaphson`.
12+
13+
### Keyword Arguments
14+
15+
- `autodiff`: determines the backend used for the Jacobian. Defaults to `nothing` (i.e.
16+
automatic backend selection). Valid choices include jacobian backends from
17+
`DifferentiationInterface.jl`.
18+
"""
19+
@kwdef @concrete struct SimpleHalley <: AbstractSimpleNonlinearSolveAlgorithm
20+
autodiff = nothing
21+
end
22+
23+
function SciMLBase.__solve(
24+
prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...;
25+
abstol = nothing, reltol = nothing, maxiters = 1000,
26+
alias_u0 = false, termination_condition = nothing, kwargs...)
27+
x = Utils.maybe_unaliased(prob.u0, alias_u0)
28+
fx = Utils.get_fx(prob, x)
29+
fx = Utils.eval_f(prob, fx, x)
30+
T = promote_type(eltype(fx), eltype(x))
31+
32+
iszero(fx) &&
33+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
34+
35+
abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache(
36+
prob, abstol, reltol, fx, x, termination_condition, Val(:simple))
37+
38+
autodiff = NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff)
39+
40+
@bb xo = copy(x)
41+
42+
strait = setindex_trait(x)
43+
44+
A = strait isa CanSetindex ? similar(x, length(x), length(x)) : x
45+
Aaᵢ = strait isa CanSetindex ? similar(x, length(x)) : x
46+
cᵢ = strait isa CanSetindex ? similar(x) : x
47+
48+
for _ in 1:maxiters
49+
fx, J, H = Utils.compute_jacobian_and_hessian(autodiff, prob, fx, x)
50+
51+
strait isa CannotSetindex && (A = J)
52+
53+
# Factorize Once and Reuse
54+
J_fact = if J isa Number
55+
J
56+
else
57+
fact = LinearAlgebra.lu(J; check = false)
58+
!LinearAlgebra.issuccess(fact) && return SciMLBase.build_solution(
59+
prob, alg, x, fx; retcode = ReturnCode.Unstable)
60+
fact
61+
end
62+
63+
aᵢ = J_fact \ Utils.safe_vec(fx)
64+
A_ = Utils.safe_vec(A)
65+
@bb A_ = H × aᵢ
66+
A = Utils.restructure(A, A_)
67+
68+
@bb Aaᵢ = A × aᵢ
69+
@bb A .*= -1
70+
bᵢ = J_fact \ Utils.safe_vec(Aaᵢ)
71+
72+
cᵢ_ = Utils.safe_vec(cᵢ)
73+
@bb @. cᵢ_ = (aᵢ * aᵢ) / (-aᵢ + (T(0.5) * bᵢ))
74+
cᵢ = Utils.restructure(cᵢ, cᵢ_)
75+
76+
solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob)
77+
solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
78+
79+
@bb @. x += cᵢ
80+
@bb copyto!(xo, x)
81+
end
82+
83+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
84+
end

lib/SimpleNonlinearSolve/src/utils.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,4 +178,26 @@ function compute_jacobian!!(J, prob, autodiff, fx, x, extras)
178178
return J
179179
end
180180

181+
function compute_jacobian_and_hessian(autodiff, prob, _, x::Number)
182+
H = DI.second_derivative(prob.f, autodiff, x, Constant(prob.p))
183+
fx, J = DI.value_and_derivative(prob.f, autodiff, x, Constant(prob.p))
184+
return fx, J, H
185+
end
186+
function compute_jacobian_and_hessian(autodiff, prob, fx, x)
187+
if SciMLBase.isinplace(prob)
188+
jac_fn = @closure (u, p) -> begin
189+
du = similar(fx, promote_type(eltype(fx), eltype(u)))
190+
return DI.jacobian(prob.f, du, autodiff, u, Constant(p))
191+
end
192+
J, H = DI.value_and_jacobian(jac_fn, autodiff, x, Constant(prob.p))
193+
fx = Utils.eval_f(prob, fx, x)
194+
return fx, J, H
195+
else
196+
jac_fn = @closure (u, p) -> DI.jacobian(prob.f, autodiff, u, Constant(p))
197+
J, H = DI.value_and_jacobian(jac_fn, autodiff, x, Constant(prob.p))
198+
fx = Utils.eval_f(prob, fx, x)
199+
return fx, J, H
200+
end
201+
end
202+
181203
end

0 commit comments

Comments
 (0)