@@ -2,17 +2,36 @@ module NonlinearSolveBaseForwardDiffExt
22
33using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff
44using ArrayInterface: ArrayInterface
5- using CommonSolve: solve
5+ using CommonSolve: CommonSolve, solve
6+ using ConcreteStructs: @concrete
67using DifferentiationInterface: DifferentiationInterface
78using FastClosures: @closure
89using ForwardDiff: ForwardDiff, Dual
910using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
1011 NonlinearProblem, NonlinearLeastSquaresProblem, remake
1112
12- using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils
13+ using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem,
14+ AbstractNonlinearSolveAlgorithm, Utils, InternalAPI,
15+ AbstractNonlinearSolveCache
1316
1417const DI = DifferentiationInterface
1518
# Types of the `alg` argument for which the Dual-aware `SciMLBase.__solve` and
# `SciMLBase.__init` methods in this extension are generated via `@eval`.
const ALL_SOLVER_TYPES = [Nothing, AbstractNonlinearSolveAlgorithm]
22+
# Problem aliases used for dispatch: the state type (first parameter) is a `Number`
# or an `AbstractArray`, while the parameter type (third parameter) is a
# `ForwardDiff.Dual{T, V, P}` or an array of such duals.

# `NonlinearProblem` whose parameters carry ForwardDiff dual numbers.
const DualNonlinearProblem = NonlinearProblem{
    <:Union{Number, <:AbstractArray},
    iip,
    <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}

# `NonlinearLeastSquaresProblem` whose parameters carry ForwardDiff dual numbers.
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
    <:Union{Number, <:AbstractArray},
    iip,
    <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}

# Either of the two dual-parameter problem kinds above.
const DualAbstractNonlinearProblem = Union{
    DualNonlinearProblem,
    DualNonlinearLeastSquaresProblem
}
34+
1635function NonlinearSolveBase. additional_incompatible_backend_check (
1736 prob:: AbstractNonlinearProblem , :: Union{AutoForwardDiff, AutoPolyesterForwardDiff} )
1837 return ! ForwardDiff. can_dual (eltype (prob. u0))
@@ -102,4 +121,92 @@ function NonlinearSolveBase.nonlinearsolve_dual_solution(
102121 return map (((uᵢ, pᵢ),) -> Dual {T, V, P} (uᵢ, pᵢ), zip (u, Utils. restructure (u, partials)))
103122end
104123
# Generate one `SciMLBase.__solve` method per entry of `ALL_SOLVER_TYPES` for problems
# whose parameters are ForwardDiff duals. Each method delegates the solve to
# `nonlinearsolve_forwarddiff_solve`, then rebuilds the solution with a dual-valued `u`.
for AlgT in ALL_SOLVER_TYPES
    @eval function SciMLBase.__solve(
            prob::DualAbstractNonlinearProblem, alg::$(AlgT), args...; kwargs...
    )
        primal_sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
            prob, alg, args...; kwargs...
        )

        # Combine the returned solution values with the returned partials.
        dual_u = NonlinearSolveBase.nonlinearsolve_dual_solution(
            primal_sol.u, partials, prob.p
        )

        return SciMLBase.build_solution(
            prob, alg, dual_u, primal_sol.resid;
            retcode = primal_sol.retcode,
            stats = primal_sol.stats,
            original = primal_sol.original
        )
    end
end
137+
# Cache returned by `SciMLBase.__init` for problems whose parameters are ForwardDiff
# duals. It wraps the cache of the dual-free problem together with the data needed to
# attach partials to the solution in `CommonSolve.solve!`.
@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache
    # Inner solver cache built for the dual-free problem.
    cache
    # The dual-free problem (u0 and p passed through `nodual_value`).
    prob
    # Algorithm handed to `init`; reused when building the final solution.
    alg
    # Original, dual-carrying parameters.
    p
    # `nodual_value(p)`.
    values_p
    # `ForwardDiff.partials(p)`.
    partials_p
end
146+
# Re-initialise a `NonlinearSolveForwardDiffCache`.
#
# Keyword arguments:
#   - `p`:  new parameters (may carry duals); defaults to the currently stored `cache.p`.
#   - `u0`: new initial guess; defaults to the current state of the wrapped cache.
#   - remaining keywords are forwarded to the wrapped cache's `reinit!`.
#
# Positional `args...` are accepted but not forwarded to the wrapped cache.
# Returns the (mutated) `cache`.
function InternalAPI.reinit!(
        cache::NonlinearSolveForwardDiffCache, args...;
        p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs...
)
    # The wrapped cache only ever sees dual-free values.
    plain_p = nodual_value(p)
    plain_u0 = nodual_value(u0)
    InternalAPI.reinit!(cache.cache; p = plain_p, u0 = plain_u0, kwargs...)

    # Refresh the dual bookkeeping used later by `solve!`.
    cache.p = p
    cache.values_p = nodual_value(p)
    cache.partials_p = ForwardDiff.partials(p)

    return cache
end
159+
# Generate one `SciMLBase.__init` method per entry of `ALL_SOLVER_TYPES` for problems
# whose parameters are ForwardDiff duals.
#
# Each method strips the duals from `u0` and `p`, initialises the solver cache for that
# dual-free problem, and wraps it in a `NonlinearSolveForwardDiffCache` that also keeps
# the original dual parameters so `solve!` can attach partials to the solution.
for algType in ALL_SOLVER_TYPES
    @eval function SciMLBase.__init(
            prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
    )
        p = nodual_value(prob.p)
        newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p)

        # `init` must be module-qualified: this file only imports `CommonSolve` and
        # `solve` from CommonSolve, so a bare `init` is not bound in this module.
        cache = CommonSolve.init(newprob, alg, args...; kwargs...)

        return NonlinearSolveForwardDiffCache(
            cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
        )
    end
end
172+
# Solve the wrapped dual-free problem, then attach ForwardDiff partials to the solution.
#
# Steps:
#   1. Run the inner cache to obtain the solution `uu`.
#   2. Evaluate `Jₚ` and `Jᵤ` with the `nonlinearsolve_∂f_∂p` / `nonlinearsolve_∂f_∂u`
#      helpers at `uu` and the dual-free parameters.
#   3. Form `z_arr = -Jᵤ \ Jₚ` and weight it with the partials of each entry of the
#      original (dual) parameters `cache.p`.
#   4. Build a solution whose `u` is dual-valued; `resid`, `retcode`, `stats` and
#      `original` are copied from the inner solution.
function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache)
    # `solve!` must be module-qualified: extending `CommonSolve.solve!` does not bind a
    # bare `solve!` in this module, and the file does not import it.
    sol = CommonSolve.solve!(cache.cache)
    prob = cache.prob
    uu = sol.u

    # Least-squares problems use a generated function in place of the raw residual.
    fn = prob isa NonlinearLeastSquaresProblem ?
         NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f

    Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, cache.values_p)
    Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, cache.values_p)

    z_arr = -Jᵤ \ Jₚ

    # Scale every entry of `z` by the partials carried by the dual parameter `p`.
    sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
    if uu isa Number
        # Scalar solution: pair each sensitivity entry with its parameter and
        # accumulate, so the result is a single set of partials rather than an array.
        # NOTE(review): assumes `z_arr` holds one entry per parameter when `uu` is a
        # scalar — confirm against the `∂f_∂p` / `∂f_∂u` helpers.
        partials = sum(sumfun, zip(z_arr, cache.p))
    elseif cache.p isa Number
        partials = sumfun((z_arr, cache.p))
    else
        # One column of `z_arr` per parameter entry.
        partials = sum(sumfun, zip(eachcol(z_arr), cache.p))
    end

    dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, cache.p)
    return SciMLBase.build_solution(
        prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
    )
end
198+
# Strip ForwardDiff dual information from `x`:
#   - a `Dual` is reduced to `ForwardDiff.value(x)`;
#   - an array of `Dual`s is reduced element by element;
#   - anything else is returned unchanged.
function nodual_value(x)
    return x
end

function nodual_value(x::Dual)
    return ForwardDiff.value(x)
end

function nodual_value(x::AbstractArray{<:Dual})
    return map(x) do xᵢ
        ForwardDiff.value(xᵢ)
    end
end
202+
"""
    pickchunksize(x)
    pickchunksize(x::Int)

Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input
length.

An `Int` argument is treated as the input length and passed to
`ForwardDiff.pickchunksize`. Any other argument is first reduced to `length(x)`.
"""
@inline function pickchunksize(x)
    return pickchunksize(length(x))
end

@inline function pickchunksize(x::Int)
    return ForwardDiff.pickchunksize(x)
end
211+
105212end
0 commit comments