@@ -20,11 +20,20 @@ A low-overhead implementation of Halley's Method.
2020    autodiff =  nothing 
2121end 
2222
23+ function  configure_autodiff (prob, alg:: SimpleHalley )
24+     autodiff =  something (alg. autodiff, AutoForwardDiff ())
25+     autodiff =  SciMLBase. has_jac (prob. f) ?  autodiff : 
26+                NonlinearSolveBase. select_jacobian_autodiff (prob, autodiff)
27+     @set!  alg. autodiff =  autodiff
28+     alg
29+ end 
30+ 
2331function  SciMLBase. __solve (
2432        prob:: ImmutableNonlinearProblem , alg:: SimpleHalley , args... ;
2533        abstol =  nothing , reltol =  nothing , maxiters =  1000 ,
2634        alias_u0 =  false , termination_condition =  nothing , kwargs... 
2735)
36+     autodiff =  alg. autodiff
2837    x =  NLBUtils. maybe_unaliased (prob. u0, alias_u0)
2938    fx =  NLBUtils. evaluate_f (prob, x)
3039    T =  promote_type (eltype (fx), eltype (x))
@@ -36,23 +45,21 @@ function SciMLBase.__solve(
3645        prob, abstol, reltol, fx, x, termination_condition, Val (:simple )
3746    )
3847
39-     #  The way we write the 2nd order derivatives, we know Enzyme won't work there
40-     autodiff =  alg. autodiff ===  nothing  ?  AutoForwardDiff () :  alg. autodiff
41-     @set!  alg. autodiff =  autodiff
42- 
4348    @bb  xo =  copy (x)
4449
50+     fx_cache =  (SciMLBase. isinplace (prob) &&  ! SciMLBase. has_jac (prob. f)) ? 
51+                NLBUtils. safe_similar (fx) :  fx
52+     jac_cache =  Utils. prepare_jacobian (prob, autodiff, fx_cache, x)
53+ 
4554    if  NLBUtils. can_setindex (x)
46-         A =  NLBUtils. safe_similar (x, length (x), length (x))
4755        Aaᵢ =  NLBUtils. safe_similar (x, length (x))
4856        cᵢ =  NLBUtils. safe_similar (x)
4957    else 
50-         A,  Aaᵢ, cᵢ =  x, x, x
58+         Aaᵢ, cᵢ =  x, x, x
5159    end 
5260
61+     J =  Utils. compute_jacobian!! (nothing , prob, autodiff, fx_cache, x, jac_cache)
5362    for  _ in  1 : maxiters
54-         fx, J, H =  Utils. compute_jacobian_and_hessian (autodiff, prob, fx, x)
55- 
5663        NLBUtils. can_setindex (x) ||  (A =  J)
5764
5865        #  Factorize Once and Reuse
@@ -67,13 +74,8 @@ function SciMLBase.__solve(
6774        end 
6875
6976        aᵢ =  J_fact \  NLBUtils. safe_vec (fx)
70-         A_ =  NLBUtils. safe_vec (A)
71-         @bb  A_ =  H ×  aᵢ
72-         A =  NLBUtils. restructure (A, A_)
73- 
74-         @bb  Aaᵢ =  A ×  aᵢ
75-         @bb  A .*=  - 1 
76-         bᵢ =  J_fact \  NLBUtils. safe_vec (Aaᵢ)
77+         hvvp =  Utils. compute_hvvp (prob, autodiff, fx_cache, x, aᵢ)
78+         bᵢ =  J_fact \  NLBUtils. safe_vec (hvvp)
7779
7880        cᵢ_ =  NLBUtils. safe_vec (cᵢ)
7981        @bb  @.  cᵢ_ =  (aᵢ *  aᵢ) /  (- aᵢ +  (T (0.5 ) *  bᵢ))
@@ -84,6 +86,9 @@ function SciMLBase.__solve(
8486
8587        @bb  @.  x +=  cᵢ
8688        @bb  copyto! (xo, x)
89+ 
90+         fx =  NLBUtils. evaluate_f!! (prob, fx, x)
91+         J =  Utils. compute_jacobian!! (J, prob, autodiff, fx_cache, x, jac_cache)
8792    end 
8893
8994    return  SciMLBase. build_solution (prob, alg, x, fx; retcode =  ReturnCode. MaxIters)
0 commit comments