|
| 1 | +""" |
| 2 | + SimpleHalley(autodiff) |
| 3 | + SimpleHalley(; autodiff = nothing) |
1 | 4 |
|
| 5 | +A low-overhead implementation of Halley's Method. |
| 6 | +
|
| 7 | +!!! note |
| 8 | +
|
| 9 | + As part of the decreased overhead, this method omits some of the higher level error |
| 10 | + catching of the other methods. Thus, to see better error messages, use one of the other |
| 11 | + methods like `NewtonRaphson`. |
| 12 | +
|
| 13 | +### Keyword Arguments |
| 14 | +
|
| 15 | + - `autodiff`: determines the backend used for the Jacobian. Defaults to `nothing` (i.e. |
| 16 | + automatic backend selection). Valid choices include jacobian backends from |
| 17 | + `DifferentiationInterface.jl`. |
| 18 | +""" |
| 19 | +@kwdef @concrete struct SimpleHalley <: AbstractSimpleNonlinearSolveAlgorithm |
| 20 | + autodiff = nothing |
| 21 | +end |
| 22 | + |
| 23 | +function SciMLBase.__solve( |
| 24 | + prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...; |
| 25 | + abstol = nothing, reltol = nothing, maxiters = 1000, |
| 26 | + alias_u0 = false, termination_condition = nothing, kwargs...) |
| 27 | + x = Utils.maybe_unaliased(prob.u0, alias_u0) |
| 28 | + fx = Utils.get_fx(prob, x) |
| 29 | + fx = Utils.eval_f(prob, fx, x) |
| 30 | + T = promote_type(eltype(fx), eltype(x)) |
| 31 | + |
| 32 | + iszero(fx) && |
| 33 | + return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) |
| 34 | + |
| 35 | + abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache( |
| 36 | + prob, abstol, reltol, fx, x, termination_condition, Val(:simple)) |
| 37 | + |
| 38 | + autodiff = NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff) |
| 39 | + |
| 40 | + @bb xo = copy(x) |
| 41 | + |
| 42 | + strait = setindex_trait(x) |
| 43 | + |
| 44 | + A = strait isa CanSetindex ? similar(x, length(x), length(x)) : x |
| 45 | + Aaᵢ = strait isa CanSetindex ? similar(x, length(x)) : x |
| 46 | + cᵢ = strait isa CanSetindex ? similar(x) : x |
| 47 | + |
| 48 | + for _ in 1:maxiters |
| 49 | + fx, J, H = Utils.compute_jacobian_and_hessian(autodiff, prob, fx, x) |
| 50 | + |
| 51 | + strait isa CannotSetindex && (A = J) |
| 52 | + |
| 53 | + # Factorize Once and Reuse |
| 54 | + J_fact = if J isa Number |
| 55 | + J |
| 56 | + else |
| 57 | + fact = LinearAlgebra.lu(J; check = false) |
| 58 | + !LinearAlgebra.issuccess(fact) && return SciMLBase.build_solution( |
| 59 | + prob, alg, x, fx; retcode = ReturnCode.Unstable) |
| 60 | + fact |
| 61 | + end |
| 62 | + |
| 63 | + aᵢ = J_fact \ Utils.safe_vec(fx) |
| 64 | + A_ = Utils.safe_vec(A) |
| 65 | + @bb A_ = H × aᵢ |
| 66 | + A = Utils.restructure(A, A_) |
| 67 | + |
| 68 | + @bb Aaᵢ = A × aᵢ |
| 69 | + @bb A .*= -1 |
| 70 | + bᵢ = J_fact \ Utils.safe_vec(Aaᵢ) |
| 71 | + |
| 72 | + cᵢ_ = Utils.safe_vec(cᵢ) |
| 73 | + @bb @. cᵢ_ = (aᵢ * aᵢ) / (-aᵢ + (T(0.5) * bᵢ)) |
| 74 | + cᵢ = Utils.restructure(cᵢ, cᵢ_) |
| 75 | + |
| 76 | + solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob) |
| 77 | + solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode) |
| 78 | + |
| 79 | + @bb @. x += cᵢ |
| 80 | + @bb copyto!(xo, x) |
| 81 | + end |
| 82 | + |
| 83 | + return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters) |
| 84 | +end |
0 commit comments