@@ -20,11 +20,20 @@ A low-overhead implementation of Halley's Method.
2020 autodiff = nothing
2121end
2222
23+ function configure_autodiff (prob, alg:: SimpleHalley )
24+ autodiff = something (alg. autodiff, AutoForwardDiff ())
25+ autodiff = SciMLBase. has_jac (prob. f) ? autodiff :
26+ NonlinearSolveBase. select_jacobian_autodiff (prob, autodiff)
27+ @set! alg. autodiff = autodiff
28+ alg
29+ end
30+
2331function SciMLBase. __solve (
2432 prob:: ImmutableNonlinearProblem , alg:: SimpleHalley , args... ;
2533 abstol = nothing , reltol = nothing , maxiters = 1000 ,
2634 alias_u0 = false , termination_condition = nothing , kwargs...
2735)
36+ autodiff = alg. autodiff
2837 x = NLBUtils. maybe_unaliased (prob. u0, alias_u0)
2938 fx = NLBUtils. evaluate_f (prob, x)
3039 T = promote_type (eltype (fx), eltype (x))
@@ -36,23 +45,21 @@ function SciMLBase.__solve(
3645 prob, abstol, reltol, fx, x, termination_condition, Val (:simple )
3746 )
3847
39- # The way we write the 2nd order derivatives, we know Enzyme won't work there
40- autodiff = alg. autodiff === nothing ? AutoForwardDiff () : alg. autodiff
41- @set! alg. autodiff = autodiff
42-
4348 @bb xo = copy (x)
4449
50+ fx_cache = (SciMLBase. isinplace (prob) && ! SciMLBase. has_jac (prob. f)) ?
51+ NLBUtils. safe_similar (fx) : fx
52+ jac_cache = Utils. prepare_jacobian (prob, autodiff, fx_cache, x)
53+
4554 if NLBUtils. can_setindex (x)
46- A = NLBUtils. safe_similar (x, length (x), length (x))
4755 Aaᵢ = NLBUtils. safe_similar (x, length (x))
4856 cᵢ = NLBUtils. safe_similar (x)
4957 else
50- A, Aaᵢ, cᵢ = x, x, x
58+ Aaᵢ, cᵢ = x, x, x
5159 end
5260
61+ J = Utils. compute_jacobian!! (nothing , prob, autodiff, fx_cache, x, jac_cache)
5362 for _ in 1 : maxiters
54- fx, J, H = Utils. compute_jacobian_and_hessian (autodiff, prob, fx, x)
55-
5663 NLBUtils. can_setindex (x) || (A = J)
5764
5865 # Factorize Once and Reuse
@@ -67,13 +74,8 @@ function SciMLBase.__solve(
6774 end
6875
6976 aᵢ = J_fact \ NLBUtils. safe_vec (fx)
70- A_ = NLBUtils. safe_vec (A)
71- @bb A_ = H × aᵢ
72- A = NLBUtils. restructure (A, A_)
73-
74- @bb Aaᵢ = A × aᵢ
75- @bb A .*= - 1
76- bᵢ = J_fact \ NLBUtils. safe_vec (Aaᵢ)
77+ hvvp = Utils. compute_hvvp (prob, autodiff, fx_cache, x, aᵢ)
78+ bᵢ = J_fact \ NLBUtils. safe_vec (hvvp)
7779
7880 cᵢ_ = NLBUtils. safe_vec (cᵢ)
7981 @bb @. cᵢ_ = (aᵢ * aᵢ) / (- aᵢ + (T (0.5 ) * bᵢ))
@@ -84,6 +86,9 @@ function SciMLBase.__solve(
8486
8587 @bb @. x += cᵢ
8688 @bb copyto! (xo, x)
89+
90+ fx = NLBUtils. evaluate_f!! (prob, fx, x)
91+ J = Utils. compute_jacobian!! (J, prob, autodiff, fx_cache, x, jac_cache)
8792 end
8893
8994 return SciMLBase. build_solution (prob, alg, x, fx; retcode = ReturnCode. MaxIters)
0 commit comments