@@ -2,22 +2,22 @@ module NonlinearSolveBaseForwardDiffExt
22
33using  ADTypes:  ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff
44using  ArrayInterface:  ArrayInterface
5- using  CommonSolve:  CommonSolve, solve
5+ using  CommonSolve:  CommonSolve, solve, solve!, init 
66using  ConcreteStructs:  @concrete 
77using  DifferentiationInterface:  DifferentiationInterface
88using  FastClosures:  @closure 
99using  ForwardDiff:  ForwardDiff, Dual
1010using  SciMLBase:  SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
1111                 NonlinearProblem, NonlinearLeastSquaresProblem, remake
1212
13- using  NonlinearSolveBase:  NonlinearSolveBase, ImmutableNonlinearProblem,
14-                           AbstractNonlinearSolveAlgorithm, Utils, InternalAPI ,
15-                           AbstractNonlinearSolveCache, NonlinearSolvePolyAlgorithm 
13+ using  NonlinearSolveBase:  NonlinearSolveBase, ImmutableNonlinearProblem, Utils, InternalAPI, 
14+                           AbstractNonlinearSolveCache, NonlinearSolvePolyAlgorithm ,
15+                           NonlinearSolveForwardDiffCache 
1616
1717const  DI =  DifferentiationInterface
1818
1919const  GENERAL_SOLVER_TYPES =  [
20-     Nothing, AbstractNonlinearSolveAlgorithm,  NonlinearSolvePolyAlgorithm
20+     Nothing, NonlinearSolvePolyAlgorithm
2121]
2222
2323const  DualNonlinearProblem =  NonlinearProblem{
@@ -135,24 +135,16 @@ for algType in GENERAL_SOLVER_TYPES
135135    end 
136136end 
137137
138- @concrete  mutable struct  NonlinearSolveForwardDiffCache <:  AbstractNonlinearSolveCache 
139-     cache
140-     prob
141-     alg
142-     p
143-     values_p
144-     partials_p
145- end 
146- 
147138function  InternalAPI. reinit! (
148139        cache:: NonlinearSolveForwardDiffCache , args... ;
149140        p =  cache. p, u0 =  NonlinearSolveBase. get_u (cache. cache), kwargs... 
150141)
151142    InternalAPI. reinit! (
152-         cache. cache; p =  nodual_value (p), u0 =  nodual_value (u0), kwargs... 
143+         cache. cache; p =  NonlinearSolveBase. nodual_value (p),
144+         u0 =  NonlinearSolveBase. nodual_value (u0), kwargs... 
153145    )
154146    cache. p =  p
155-     cache. values_p =  nodual_value (p)
147+     cache. values_p =  NonlinearSolveBase . nodual_value (p)
156148    cache. partials_p =  ForwardDiff. partials (p)
157149    return  cache
158150end 
@@ -161,8 +153,8 @@ for algType in GENERAL_SOLVER_TYPES
161153    @eval  function  SciMLBase. __init (
162154            prob:: DualAbstractNonlinearProblem , alg:: $ (algType), args... ; kwargs... 
163155    )
164-         p =  nodual_value (prob. p)
165-         newprob =  SciMLBase. remake (prob; u0 =  nodual_value (prob. u0), p)
156+         p =  NonlinearSolveBase . nodual_value (prob. p)
157+         newprob =  SciMLBase. remake (prob; u0 =  NonlinearSolveBase . nodual_value (prob. u0), p)
166158        cache =  init (newprob, alg, args... ; kwargs... )
167159        return  NonlinearSolveForwardDiffCache (
168160            cache, newprob, alg, prob. p, p, ForwardDiff. partials (prob. p)
@@ -196,8 +188,17 @@ function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache)
196188    )
197189end 
198190
199- nodual_value (x) =  x
200- nodual_value (x:: Dual ) =  ForwardDiff. value (x)
201- nodual_value (x:: AbstractArray{<:Dual} ) =  map (ForwardDiff. value, x)
191+ NonlinearSolveBase. nodual_value (x) =  x
192+ NonlinearSolveBase. nodual_value (x:: Dual ) =  ForwardDiff. value (x)
193+ NonlinearSolveBase. nodual_value (x:: AbstractArray{<:Dual} ) =  map (ForwardDiff. value, x)
194+ 
195+ """ 
196+     pickchunksize(x) = pickchunksize(length(x)) 
197+     pickchunksize(x::Int) 
198+ 
199+ Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length. 
200+ """ 
201+ @inline  NonlinearSolveBase. pickchunksize (x) =  pickchunksize (length (x))
202+ @inline  NonlinearSolveBase. pickchunksize (x:: Int ) =  ForwardDiff. pickchunksize (x)
202203
203204end 
0 commit comments