|  | 
|  | 1 | +module NonlinearSolveQuasiNewtonForwardDiffExt | 
|  | 2 | + | 
|  | 3 | +using CommonSolve: CommonSolve, solve | 
|  | 4 | +using ForwardDiff: ForwardDiff, Dual | 
|  | 5 | +using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, | 
|  | 6 | +                 NonlinearProblem, NonlinearLeastSquaresProblem, remake | 
|  | 7 | + | 
|  | 8 | +using NonlinearSolveBase: NonlinearSolveBase | 
|  | 9 | + | 
|  | 10 | +using NonlinearSolveQuasiNewton: QuasiNewtonAlgorithm | 
|  | 11 | + | 
|  | 12 | +const DualNonlinearProblem = NonlinearProblem{ | 
|  | 13 | +    <:Union{Number, <:AbstractArray}, iip, | 
|  | 14 | +    <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} | 
|  | 15 | +} where {iip, T, V, P} | 
|  | 16 | +const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{ | 
|  | 17 | +    <:Union{Number, <:AbstractArray}, iip, | 
|  | 18 | +    <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} | 
|  | 19 | +} where {iip, T, V, P} | 
|  | 20 | +const DualAbstractNonlinearProblem = Union{ | 
|  | 21 | +    DualNonlinearProblem, DualNonlinearLeastSquaresProblem | 
|  | 22 | +} | 
|  | 23 | + | 
|  | 24 | +function SciMLBase.__solve( | 
|  | 25 | +        prob::DualAbstractNonlinearProblem, alg::QuasiNewtonAlgorithm, args...; kwargs... | 
|  | 26 | +) | 
|  | 27 | +    sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( | 
|  | 28 | +        prob, alg, args...; kwargs... | 
|  | 29 | +    ) | 
|  | 30 | +    dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p) | 
|  | 31 | +    return SciMLBase.build_solution( | 
|  | 32 | +        prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original | 
|  | 33 | +    ) | 
|  | 34 | +end | 
|  | 35 | + | 
|  | 36 | +function SciMLBase.__init( | 
|  | 37 | +        prob::DualAbstractNonlinearProblem, alg::QuasiNewtonAlgorithm, args...; kwargs... | 
|  | 38 | +) | 
|  | 39 | +    p = nodual_value(prob.p) | 
|  | 40 | +    newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p) | 
|  | 41 | +    cache = init(newprob, alg, args...; kwargs...) | 
|  | 42 | +    return NonlinearSolveForwardDiffCache( | 
|  | 43 | +        cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p) | 
|  | 44 | +    ) | 
|  | 45 | +end | 
|  | 46 | + | 
|  | 47 | +end | 
0 commit comments