@@ -19,49 +19,49 @@ function solve(prob::NonlinearProblem{<:Number}, ::NewtonRaphson, args...; xatol
19
19
return NewtonSolution (x, MAXITERS_EXCEED)
20
20
end
21
21
22
- function solve (prob:: NonlinearProblem{<:Number, iip, <:ForwardDiff.Dual{T,V,P}} , alg:: NewtonRaphson , args... ; kwargs... ) where {uType, iip, T, V, P}
22
+ function scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
23
23
f = prob. f
24
- p = ForwardDiff. value (prob. p)
25
- u0 = ForwardDiff. value (prob. u0)
24
+ p = value (prob. p)
25
+ u0 = value (prob. u0)
26
+
26
27
newprob = NonlinearProblem (f, u0, p; prob. kwargs... )
27
28
sol = solve (newprob, alg, args... ; kwargs... )
28
- f_p = ForwardDiff. derivative (Base. Fix1 (f, sol. u), p)
29
- f_x = ForwardDiff. derivative (Base. Fix2 (f, p), sol. u)
30
- partials = (- f_p / f_x) * ForwardDiff. partials (prob. p)
31
- return NewtonSolution (ForwardDiff. Dual {T,V,P} (sol. u, partials), sol. retcode)
29
+
30
+ uu = getsolution (sol)
31
+ if p isa Number
32
+ f_p = ForwardDiff. derivative (Base. Fix1 (f, uu), p)
33
+ else
34
+ f_p = ForwardDiff. gradient (Base. Fix1 (f, uu), p)
35
+ end
36
+
37
+ f_x = ForwardDiff. derivative (Base. Fix2 (f, p), uu)
38
+ pp = prob. p
39
+ sumfun = let f_x′ = - f_x
40
+ ((fp, p),) -> (fp / f_x′) * ForwardDiff. partials (p)
41
+ end
42
+ partials = sum (sumfun, zip (f_p, pp))
43
+ return sol, partials
32
44
end
33
45
34
- function solve (prob:: NonlinearProblem{uType, iip, <:ForwardDiff.Dual{T,V,P}} , alg:: Bisection , args... ; kwargs... ) where {uType, iip, T, V, P}
35
- prob_nodual = NonlinearProblem (prob. f, prob. u0, ForwardDiff. value (prob. p); prob. kwargs... )
36
- sol = solve (prob_nodual, alg, args... ; kwargs... )
37
- # f, x and p always satisfy
38
- # f(x, p) = 0
39
- # dx * f_x(x, p) + dp * f_p(x, p) = 0
40
- # dx / dp = - f_p(x, p) / f_x(x, p)
41
- f_p = (p) -> prob. f (sol. left, p)
42
- f_x = (x) -> prob. f (x, ForwardDiff. value (prob. p))
43
- d_p = ForwardDiff. derivative (f_p, ForwardDiff. value (prob. p))
44
- d_x = ForwardDiff. derivative (f_x, sol. left)
45
- partials = - d_p / d_x * ForwardDiff. partials (prob. p)
46
- return BracketingSolution (ForwardDiff. Dual {T,V,P} (sol. left, partials), ForwardDiff. Dual {T,V,P} (sol. right, partials), sol. retcode)
46
+ function solve (prob:: NonlinearProblem{<:Number, iip, <:Dual{T,V,P}} , alg:: NewtonRaphson , args... ; kwargs... ) where {iip, T, V, P}
47
+ sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
48
+ return NewtonSolution (Dual {T,V,P} (sol. u, partials), sol. retcode)
49
+ end
50
+ function solve (prob:: NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}} , alg:: NewtonRaphson , args... ; kwargs... ) where {iip, T, V, P}
51
+ sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
52
+ return NewtonSolution (Dual {T,V,P} (sol. u, partials), sol. retcode)
47
53
end
48
54
49
- # still WIP
50
- function solve (prob:: NonlinearProblem{uType, iip, <:AbstractArray{<:ForwardDiff.Dual{T,V,P}, N}} , alg:: Bisection , args... ; kwargs... ) where {uType, iip, T, V, P, N}
51
- p_nodual = ForwardDiff. value .(prob. p)
52
- prob_nodual = NonlinearProblem (prob. f, prob. u0, p_nodual; prob. kwargs... )
53
- sol = solve (prob_nodual, alg, args... ; kwargs... )
54
- # f, x and p always satisfy
55
- # f(x, p) = 0
56
- # dx * f_x(x, p) + dp * f_p(x, p) = 0
57
- # dx / dp = - f_p(x, p) / f_x(x, p)
58
- f_p = (p) -> [ prob. f (sol. left, p) ]
59
- f_x = (x) -> prob. f (x, p_nodual)
60
- d_p = ForwardDiff. jacobian (f_p, p_nodual)
61
- d_x = ForwardDiff. derivative (f_x, sol. left)
62
- @. d_p = - d_p / d_x
63
- @show ForwardDiff. partials .(prob. p)
64
- return ForwardDiff. Dual {T,V,P} (sol. left, d_p * ForwardDiff. partials .(prob. p))
55
+ # avoid ambiguities
56
+ for Alg in [Bisection, Falsi]
57
+ @eval function solve (prob:: NonlinearProblem{uType, iip, <:Dual{T,V,P}} , alg:: $Alg , args... ; kwargs... ) where {uType, iip, T, V, P}
58
+ sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
59
+ return BracketingSolution (Dual {T,V,P} (sol. left, partials), Dual {T,V,P} (sol. right, partials), sol. retcode)
60
+ end
61
+ @eval function solve (prob:: NonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T,V,P}}} , alg:: $Alg , args... ; kwargs... ) where {uType, iip, T, V, P}
62
+ sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
63
+ return BracketingSolution (Dual {T,V,P} (sol. left, partials), Dual {T,V,P} (sol. right, partials), sol. retcode)
64
+ end
65
65
end
66
66
67
67
function solve (prob:: NonlinearProblem , :: Bisection , args... ; maxiters = 1000 , kwargs... )
0 commit comments